From 209e3f96fa83dee69bf415bae77136d265584570 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Mon, 24 Nov 2025 13:33:17 -0800 Subject: [PATCH 1/3] Add retrying of workflow validation errors only on python 3.10 --- .github/workflows/ci.yml | 2 ++ pyproject.toml | 1 + uv.lock | 15 +++++++++++++++ 3 files changed, 18 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e24187c5f..06e953692 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,6 +32,8 @@ jobs: - os: ubuntu-latest python: "3.10" protoCheckTarget: true + - python: "3.10" + pytestExtraArgs: "--reruns 3 --only-rerun \"RuntimeError: Failed validating workflow\"" - os: ubuntu-arm runsOn: ubuntu-24.04-arm64-2-core - os: macos-intel diff --git a/pyproject.toml b/pyproject.toml index 9832c8cf0..51f81e08f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ dev = [ "openai-agents>=0.3,<0.5; python_version >= '3.14'", "openai-agents[litellm]>=0.3,<0.4; python_version < '3.14'", "googleapis-common-protos==1.70.0", + "pytest-rerunfailures>=16.1", ] [tool.poe.tasks] diff --git a/uv.lock b/uv.lock index 8b1ee82a8..252bf4059 100644 --- a/uv.lock +++ b/uv.lock @@ -2394,6 +2394,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/85/2f97a1b65178b0f11c9c77c35417a4cc5b99a80db90dad4734a129844ea5/pytest_pretty-1.3.0-py3-none-any.whl", hash = "sha256:074b9d5783cef9571494543de07e768a4dda92a3e85118d6c7458c67297159b7", size = 5620, upload-time = "2025-06-04T12:54:36.229Z" }, ] +[[package]] +name = "pytest-rerunfailures" +version = "16.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/de/04/71e9520551fc8fe2cf5c1a1842e4e600265b0815f2016b7c27ec85688682/pytest_rerunfailures-16.1.tar.gz", hash = "sha256:c38b266db8a808953ebd71ac25c381cb1981a78ff9340a14bcb9f1b9bff1899e", size = 30889, upload-time = "2025-10-10T07:06:01.238Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/54/60eabb34445e3db3d3d874dc1dfa72751bfec3265bd611cb13c8b290adea/pytest_rerunfailures-16.1-py3-none-any.whl", hash = "sha256:5d11b12c0ca9a1665b5054052fcc1084f8deadd9328962745ef6b04e26382e86", size = 14093, upload-time = "2025-10-10T07:06:00.019Z" }, +] + [[package]] name = "pytest-timeout" version = "2.4.0" @@ -2997,6 +3010,7 @@ dev = [ { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-pretty" }, + { name = "pytest-rerunfailures" }, { name = "pytest-timeout" }, { name = "ruff" }, { name = "toml" }, @@ -3038,6 +3052,7 @@ dev = [ { name = "pytest-asyncio", specifier = ">=0.21,<0.22" }, { name = "pytest-cov", specifier = ">=6.1.1" }, { name = "pytest-pretty", specifier = ">=1.3.0" }, + { name = "pytest-rerunfailures", specifier = ">=16.1" }, { name = "pytest-timeout", specifier = "~=2.2" }, { name = "ruff", specifier = ">=0.5.0,<0.6" }, { name = "toml", specifier = ">=0.10.2,<0.11" }, From cf02c2b245b455a6b24be7efcf56a5f3d176e73b Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 25 Nov 2025 11:40:41 -0800 Subject: [PATCH 2/3] use session scoped shared state manager to ensure shutdown occurs. --- tests/conftest.py | 20 +- tests/worker/test_activity.py | 382 ++++++++++++++++++++++++++++------ 2 files changed, 330 insertions(+), 72 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c0f8bc5e0..bf0f6856e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,11 @@ import pytest import pytest_asyncio +from temporalio.client import Client +from temporalio.testing import WorkflowEnvironment +from tests.helpers.worker import ExternalPythonWorker, ExternalWorker +from temporalio.worker import SharedStateManager + from . import DEV_SERVER_DOWNLOAD_VERSION # If there is an integration test environment variable set, we must remove the @@ -38,10 +43,6 @@ or protobuf_version.startswith("6.") ), f"Expected protobuf 4.x/5.x/6.x, got {protobuf_version}" -from temporalio.client import Client -from temporalio.testing import WorkflowEnvironment -from tests.helpers.worker import ExternalPythonWorker, ExternalWorker - def pytest_runtest_setup(item): """Print a newline so that custom printed output starts on new line.""" @@ -134,6 +135,17 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: await env.shutdown() +@pytest.fixture(scope="session") +def shared_state_manager() -> Iterator[SharedStateManager]: + mp_mgr = multiprocessing.Manager() + mgr = SharedStateManager.create_from_multiprocessing(mp_mgr) + + try: + yield mgr + finally: + mp_mgr.shutdown() + + @pytest.fixture(scope="session") def mp_fork_ctx() -> Iterator[multiprocessing.context.BaseContext | None]: mp_ctx = None diff --git a/tests/worker/test_activity.py b/tests/worker/test_activity.py index 203b89a5a..ef040071a 100644 --- a/tests/worker/test_activity.py +++ b/tests/worker/test_activity.py @@ -2,7 +2,6 @@ import concurrent.futures import logging import logging.handlers -import multiprocessing import os import queue import signal @@ -19,6 +18,8 @@ import pytest +import temporalio.exceptions +import temporalio.api.common.v1 import temporalio.api.workflowservice.v1 from temporalio import activity, workflow from temporalio.client import ( @@ -52,34 +53,53 @@ kitchen_sink_retry_policy, ) -_default_shared_state_manager = SharedStateManager.create_from_multiprocessing( - multiprocessing.Manager() -) default_max_concurrent_activities = 50 -async def test_activity_hello(client: Client, worker: ExternalWorker): +async def test_activity_hello( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn async def say_hello(name: str) -> str: return f"Hello, {name}!" result = await _execute_workflow_with_activity( - client, worker, say_hello, "Temporal" + client, + worker, + say_hello, + "Temporal", + shared_state_manager=shared_state_manager, ) assert result.result == "Hello, Temporal!" -async def test_activity_without_decorator(client: Client, worker: ExternalWorker): +async def test_activity_without_decorator( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): async def say_hello(name: str) -> str: return f"Hello, {name}!" with pytest.raises(TypeError) as err: - await _execute_workflow_with_activity(client, worker, say_hello, "Temporal") + await _execute_workflow_with_activity( + client, + worker, + say_hello, + "Temporal", + shared_state_manager=shared_state_manager, + ) assert "Activity say_hello missing attributes" in str(err.value) -async def test_activity_custom_name(client: Client, worker: ExternalWorker): +async def test_activity_custom_name( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn(name="my custom activity name!") async def get_name(name: str) -> str: return f"Name: {activity.info().activity_type}" @@ -90,12 +110,15 @@ async def get_name(name: str) -> str: get_name, "Temporal", activity_name_override="my custom activity name!", + shared_state_manager=shared_state_manager, ) assert result.result == "Name: my custom activity name!" async def test_client_available_in_async_activities( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): with pytest.raises(RuntimeError, match="Not in activity context"): activity.client() @@ -107,12 +130,19 @@ async def capture_client() -> None: nonlocal captured_client captured_client = activity.client() - await _execute_workflow_with_activity(client, worker, capture_client) + await _execute_workflow_with_activity( + client, + worker, + capture_client, + shared_state_manager=shared_state_manager, + ) assert captured_client is client async def test_client_not_available_in_sync_activities( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): saw_error = False @@ -133,12 +163,16 @@ def some_activity() -> None: "activity_executor": concurrent.futures.ThreadPoolExecutor(1), "max_concurrent_activities": 1, }, + shared_state_manager=shared_state_manager, ) assert saw_error async def test_activity_info( - client: Client, worker: ExternalWorker, env: WorkflowEnvironment + client: Client, + worker: ExternalWorker, + env: WorkflowEnvironment, + shared_state_manager: SharedStateManager, ): # TODO(cretz): Fix if env.supports_time_skipping: @@ -160,7 +194,11 @@ async def capture_info() -> None: info = activity.info() result = await _execute_workflow_with_activity( - client, worker, capture_info, start_to_close_timeout_ms=4000 + client, + worker, + capture_info, + start_to_close_timeout_ms=4000, + shared_state_manager=shared_state_manager, ) assert info @@ -186,7 +224,11 @@ async def capture_info() -> None: assert info.retry_policy == kitchen_sink_retry_policy() -async def test_sync_activity_thread(client: Client, worker: ExternalWorker): +async def test_sync_activity_thread( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn def some_activity() -> str: return f"activity name: {activity.info().activity_type}" @@ -203,6 +245,7 @@ def some_activity() -> str: worker, some_activity, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert result.result == "activity name: some_activity" @@ -212,7 +255,11 @@ def picklable_activity() -> str: return f"activity name: {activity.info().activity_type}" -async def test_sync_activity_process(client: Client, worker: ExternalWorker): +async def test_sync_activity_process( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): # We intentionally leave max_workers by default in the process pool executor # to confirm that the warning is triggered with concurrent.futures.ProcessPoolExecutor() as executor: @@ -225,12 +272,15 @@ async def test_sync_activity_process(client: Client, worker: ExternalWorker): worker, picklable_activity, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert result.result == "activity name: picklable_activity" async def test_sync_activity_process_non_picklable( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): @activity.defn def some_activity() -> str: @@ -243,17 +293,27 @@ def some_activity() -> str: worker, some_activity, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert "must be picklable when using a process executor" in str(err.value) -async def test_activity_failure(client: Client, worker: ExternalWorker): +async def test_activity_failure( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn async def raise_error(): raise RuntimeError("oh no!") with pytest.raises(WorkflowFailureError) as err: - await _execute_workflow_with_activity(client, worker, raise_error) + await _execute_workflow_with_activity( + client, + worker, + raise_error, + shared_state_manager=shared_state_manager, + ) assert str(assert_activity_application_error(err.value)) == "RuntimeError: oh no!" @@ -262,7 +322,11 @@ def picklable_activity_failure(): raise RuntimeError("oh no!") -async def test_sync_activity_process_failure(client: Client, worker: ExternalWorker): +async def test_sync_activity_process_failure( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): with pytest.raises(WorkflowFailureError) as err: with concurrent.futures.ProcessPoolExecutor() as executor: await _execute_workflow_with_activity( @@ -270,17 +334,27 @@ async def test_sync_activity_process_failure(client: Client, worker: ExternalWor worker, picklable_activity_failure, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert str(assert_activity_application_error(err.value)) == "RuntimeError: oh no!" -async def test_activity_bad_params(client: Client, worker: ExternalWorker): +async def test_activity_bad_params( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn async def say_hello(name: str) -> str: return f"Hello, {name}!" with pytest.raises(WorkflowFailureError) as err: - await _execute_workflow_with_activity(client, worker, say_hello) + await _execute_workflow_with_activity( + client, + worker, + say_hello, + shared_state_manager=shared_state_manager, + ) assert str(assert_activity_application_error(err.value)).endswith( "missing 1 required positional argument: 'name'" ) @@ -296,7 +370,11 @@ async def say_hello(*, name: str) -> str: assert str(err.value).endswith("cannot have keyword-only arguments") -async def test_activity_cancel_catch(client: Client, worker: ExternalWorker): +async def test_activity_cancel_catch( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn async def wait_cancel() -> str: try: @@ -313,11 +391,16 @@ async def wait_cancel() -> str: cancel_after_ms=100, wait_for_cancellation=True, heartbeat_timeout_ms=2000, + shared_state_manager=shared_state_manager, ) assert result.result == "Got cancelled error, cancelled? True" -async def test_activity_cancel_throw(client: Client, worker: ExternalWorker): +async def test_activity_cancel_throw( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn async def wait_cancel() -> str: while True: @@ -332,13 +415,16 @@ async def wait_cancel() -> str: cancel_after_ms=100, wait_for_cancellation=True, heartbeat_timeout_ms=1000, + shared_state_manager=shared_state_manager, ) assert isinstance(err.value.cause, ActivityError) assert isinstance(err.value.cause.cause, CancelledError) async def test_sync_activity_thread_cancel_caught( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): @activity.defn def wait_cancel() -> str: @@ -361,12 +447,15 @@ def wait_cancel() -> str: wait_for_cancellation=True, heartbeat_timeout_ms=3000, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert result.result == "Cancelled" async def test_sync_activity_thread_cancel_uncaught( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): @activity.defn def wait_cancel() -> NoReturn: @@ -386,13 +475,16 @@ def wait_cancel() -> NoReturn: wait_for_cancellation=True, heartbeat_timeout_ms=3000, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert isinstance(err.value.cause, ActivityError) assert isinstance(err.value.cause.cause, CancelledError) async def test_sync_activity_thread_cancel_exception_disabled( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): @activity.defn(no_thread_cancel_exception=True) def wait_cancel() -> str: @@ -416,12 +508,15 @@ def wait_cancel() -> str: wait_for_cancellation=True, heartbeat_timeout_ms=3000, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert result.result == "Cancelled" async def test_sync_activity_thread_cancel_exception_shielded( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): events: List[str] = [] @@ -451,6 +546,7 @@ def wait_cancel() -> None: wait_for_cancellation=True, heartbeat_timeout_ms=3000, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert isinstance(err.value.cause, ActivityError) assert isinstance(err.value.cause.cause, CancelledError) @@ -530,7 +626,11 @@ def picklable_activity_wait_cancel() -> str: return "Cancelled" -async def test_sync_activity_process_cancel(client: Client, worker: ExternalWorker): +async def test_sync_activity_process_cancel( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): with concurrent.futures.ProcessPoolExecutor() as executor: result = await _execute_workflow_with_activity( client, @@ -540,6 +640,7 @@ async def test_sync_activity_process_cancel(client: Client, worker: ExternalWork wait_for_cancellation=True, heartbeat_timeout_ms=3000, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert result.result == "Cancelled" @@ -553,7 +654,9 @@ def picklable_activity_raise_cancel() -> str: async def test_sync_activity_process_cancel_uncaught( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): with pytest.raises(WorkflowFailureError) as err: with concurrent.futures.ProcessPoolExecutor() as executor: @@ -565,6 +668,7 @@ async def test_sync_activity_process_cancel_uncaught( wait_for_cancellation=True, heartbeat_timeout_ms=5000, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert isinstance(err.value.cause, ActivityError) assert isinstance(err.value.cause.cause, CancelledError) @@ -595,7 +699,11 @@ async def say_hello(name: str) -> str: assert "is not registered" in str(assert_activity_application_error(err.value)) -async def test_max_concurrent_activities(client: Client, worker: ExternalWorker): +async def test_max_concurrent_activities( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): seen_indexes: List[int] = [] complete_activities_event = asyncio.Event() @@ -619,6 +727,7 @@ async def some_activity(index: int) -> str: schedule_to_start_timeout_ms=1000, worker_config={"max_concurrent_activities": 42}, on_complete=complete_activities_event.set, + shared_state_manager=shared_state_manager, ) timeout = assert_activity_error(err.value) assert isinstance(timeout, TimeoutError) @@ -636,7 +745,11 @@ class SomeClass2: bar: Optional[SomeClass1] = None -async def test_activity_type_hints(client: Client, worker: ExternalWorker): +async def test_activity_type_hints( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): activity_param1: SomeClass2 @activity.defn @@ -651,6 +764,7 @@ async def some_activity(param1: SomeClass2, param2: str) -> str: some_activity, SomeClass2(foo="str1", bar=SomeClass1(foo=123)), "123", + shared_state_manager=shared_state_manager, ) assert ( result.result @@ -660,7 +774,10 @@ async def some_activity(param1: SomeClass2, param2: str) -> str: async def test_activity_heartbeat_details( - client: Client, worker: ExternalWorker, env: WorkflowEnvironment + client: Client, + worker: ExternalWorker, + env: WorkflowEnvironment, + shared_state_manager: SharedStateManager, ): if env.supports_time_skipping: pytest.skip("https://github.com/temporalio/sdk-java/issues/2459") @@ -681,6 +798,7 @@ async def some_activity() -> str: worker, some_activity, retry_max_attempts=4, + shared_state_manager=shared_state_manager, ) assert result.result == "final count: 36" @@ -690,7 +808,9 @@ class NotSerializableValue: async def test_activity_heartbeat_details_converter_fail( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): @activity.defn async def some_activity() -> str: @@ -702,14 +822,22 @@ async def some_activity() -> str: return "Should not get here" with pytest.raises(WorkflowFailureError) as err: - await _execute_workflow_with_activity(client, worker, some_activity) + await _execute_workflow_with_activity( + client, + worker, + some_activity, + shared_state_manager=shared_state_manager, + ) assert str(assert_activity_application_error(err.value)).endswith( "is not JSON serializable" ) async def test_activity_heartbeat_details_timeout( - client: Client, worker: ExternalWorker, env: WorkflowEnvironment + client: Client, + worker: ExternalWorker, + env: WorkflowEnvironment, + shared_state_manager: SharedStateManager, ): # TODO(cretz): Fix if env.supports_time_skipping: @@ -727,7 +855,11 @@ async def some_activity() -> str: # then check the timeout's details with pytest.raises(WorkflowFailureError) as err: await _execute_workflow_with_activity( - client, worker, some_activity, heartbeat_timeout_ms=1000 + client, + worker, + some_activity, + heartbeat_timeout_ms=1000, + shared_state_manager=shared_state_manager, ) timeout = assert_activity_error(err.value) assert isinstance(timeout, TimeoutError) @@ -751,7 +883,10 @@ def picklable_heartbeat_details_activity() -> str: async def test_sync_activity_thread_heartbeat_details( - client: Client, worker: ExternalWorker, env: WorkflowEnvironment + client: Client, + worker: ExternalWorker, + env: WorkflowEnvironment, + shared_state_manager: SharedStateManager, ): if env.supports_time_skipping: pytest.skip("https://github.com/temporalio/sdk-java/issues/2459") @@ -765,12 +900,16 @@ async def test_sync_activity_thread_heartbeat_details( picklable_heartbeat_details_activity, retry_max_attempts=2, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert result.result == "attempt: 1, attempt: 2" async def test_sync_activity_process_heartbeat_details( - client: Client, worker: ExternalWorker, env: WorkflowEnvironment + client: Client, + worker: ExternalWorker, + env: WorkflowEnvironment, + shared_state_manager: SharedStateManager, ): if env.supports_time_skipping: pytest.skip("https://github.com/temporalio/sdk-java/issues/2459") @@ -782,6 +921,7 @@ async def test_sync_activity_process_heartbeat_details( picklable_heartbeat_details_activity, retry_max_attempts=2, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) assert result.result == "attempt: 1, attempt: 2" @@ -793,7 +933,9 @@ def picklable_activity_non_pickable_heartbeat_details() -> str: async def test_sync_activity_process_non_picklable_heartbeat_details( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): with pytest.raises(WorkflowFailureError) as err: with concurrent.futures.ProcessPoolExecutor() as executor: @@ -802,6 +944,7 @@ async def test_sync_activity_process_non_picklable_heartbeat_details( worker, picklable_activity_non_pickable_heartbeat_details, worker_config={"activity_executor": executor}, + shared_state_manager=shared_state_manager, ) msg = str(assert_activity_application_error(err.value)) # TODO: different messages can apparently be produced across runs/platforms @@ -813,7 +956,11 @@ async def test_sync_activity_process_non_picklable_heartbeat_details( ) -async def test_activity_error_non_retryable(client: Client, worker: ExternalWorker): +async def test_activity_error_non_retryable( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn async def some_activity(): if activity.info().attempt < 2: @@ -827,6 +974,7 @@ async def some_activity(): worker, some_activity, retry_max_attempts=100, + shared_state_manager=shared_state_manager, ) app_err = assert_activity_application_error(err.value) assert str(app_err) == "Do not retry me" @@ -834,7 +982,9 @@ async def some_activity(): async def test_activity_error_non_retryable_type( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): @activity.defn async def some_activity(): @@ -849,6 +999,7 @@ async def some_activity(): some_activity, retry_max_attempts=100, non_retryable_error_types=["Cannot retry me"], + shared_state_manager=shared_state_manager, ) assert ( str(assert_activity_application_error(err.value)) @@ -856,7 +1007,11 @@ async def some_activity(): ) -async def test_activity_logging(client: Client, worker: ExternalWorker): +async def test_activity_logging( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn async def say_hello(name: str) -> str: activity.logger.info(f"Called with arg: {name}") @@ -869,7 +1024,11 @@ async def say_hello(name: str) -> str: activity.logger.base_logger.setLevel(logging.INFO) try: result = await _execute_workflow_with_activity( - client, worker, say_hello, "Temporal" + client, + worker, + say_hello, + "Temporal", + shared_state_manager=shared_state_manager, ) finally: activity.logger.base_logger.removeHandler(handler) @@ -883,7 +1042,11 @@ async def say_hello(name: str) -> str: assert records[-1].__dict__["temporal_activity"]["activity_type"] == "say_hello" -async def test_activity_worker_shutdown(client: Client, worker: ExternalWorker): +async def test_activity_worker_shutdown( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): activity_started = asyncio.Event() @activity.defn @@ -971,7 +1134,9 @@ def picklable_wait_on_event() -> str: async def test_sync_activity_process_worker_shutdown_graceful( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): act_task_queue = str(uuid.uuid4()) with concurrent.futures.ProcessPoolExecutor() as executor: @@ -982,7 +1147,7 @@ async def test_sync_activity_process_worker_shutdown_graceful( activity_executor=executor, max_concurrent_activities=default_max_concurrent_activities, graceful_shutdown_timeout=timedelta(seconds=2), - shared_state_manager=_default_shared_state_manager, + shared_state_manager=shared_state_manager, ) asyncio.create_task(act_worker.run()) @@ -1028,7 +1193,9 @@ def kill_my_process() -> str: async def test_sync_activity_process_executor_crash( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): act_task_queue = str(uuid.uuid4()) with concurrent.futures.ProcessPoolExecutor() as executor: @@ -1039,7 +1206,7 @@ async def test_sync_activity_process_executor_crash( activity_executor=executor, max_concurrent_activities=default_max_concurrent_activities, graceful_shutdown_timeout=timedelta(seconds=2), - shared_state_manager=_default_shared_state_manager, + shared_state_manager=shared_state_manager, ) act_worker_task = asyncio.create_task(act_worker.run()) @@ -1102,12 +1269,20 @@ def async_handle(self, client: Client, use_task_token: bool) -> AsyncActivityHan @pytest.mark.parametrize("use_task_token", [True, False]) async def test_activity_async_success( - client: Client, worker: ExternalWorker, use_task_token: bool + client: Client, + worker: ExternalWorker, + use_task_token: bool, + shared_state_manager: SharedStateManager, ): # Start task, wait for info, complete with value, wait on workflow wrapper = AsyncActivityWrapper() task = asyncio.create_task( - _execute_workflow_with_activity(client, worker, wrapper.run) + _execute_workflow_with_activity( + client, + worker, + wrapper.run, + shared_state_manager=shared_state_manager, + ) ) await wrapper.wait_info() await wrapper.async_handle(client, use_task_token).complete("some value") @@ -1116,7 +1291,12 @@ async def test_activity_async_success( # Do again with a None value wrapper = AsyncActivityWrapper() task = asyncio.create_task( - _execute_workflow_with_activity(client, worker, wrapper.run) + _execute_workflow_with_activity( + client, + worker, + wrapper.run, + shared_state_manager=shared_state_manager, + ) ) await wrapper.wait_info() await wrapper.async_handle(client, use_task_token).complete(None) @@ -1129,6 +1309,7 @@ async def test_activity_async_heartbeat_and_fail( worker: ExternalWorker, env: WorkflowEnvironment, use_task_token: bool, + shared_state_manager: SharedStateManager, ): if env.supports_time_skipping: pytest.skip("https://github.com/temporalio/sdk-java/issues/2459") @@ -1137,7 +1318,11 @@ async def test_activity_async_heartbeat_and_fail( # Start task w/ max attempts 2, wait for info, send heartbeat, fail task = asyncio.create_task( _execute_workflow_with_activity( - client, worker, wrapper.run, retry_max_attempts=2 + client, + worker, + wrapper.run, + retry_max_attempts=2, + shared_state_manager=shared_state_manager, ) ) info = await wrapper.wait_info() @@ -1165,13 +1350,21 @@ async def test_activity_async_heartbeat_and_fail( @pytest.mark.parametrize("use_task_token", [True, False]) async def test_activity_async_cancel( - client: Client, worker: ExternalWorker, use_task_token: bool + client: Client, + worker: ExternalWorker, + use_task_token: bool, + shared_state_manager: SharedStateManager, ): wrapper = AsyncActivityWrapper() # Start task, wait for info, cancel, wait on workflow task = asyncio.create_task( _execute_workflow_with_activity( - client, worker, wrapper.run, cancel_after_ms=50, wait_for_cancellation=True + client, + worker, + wrapper.run, + cancel_after_ms=50, + wait_for_cancellation=True, + shared_state_manager=shared_state_manager, ) ) await wrapper.wait_info() @@ -1203,7 +1396,11 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any: return await super().execute_activity(input) -async def test_sync_activity_contextvars(client: Client, worker: ExternalWorker): +async def test_sync_activity_contextvars( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn def some_activity() -> str: return f"context var: {some_context_var.get()}" @@ -1219,6 +1416,7 @@ def some_activity() -> str: "activity_executor": executor, "interceptors": [ContextVarInterceptor()], }, + shared_state_manager=shared_state_manager, ) assert result.result == "context var: some value!" @@ -1273,7 +1471,11 @@ def sync_dyn_activity(args: Sequence[RawValue]) -> DynActivityValue: ) -async def test_activity_dynamic(client: Client, worker: ExternalWorker): +async def test_activity_dynamic( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn(dynamic=True) async def async_dyn_activity(args: Sequence[RawValue]) -> DynActivityValue: return sync_dyn_activity(args) @@ -1286,11 +1488,16 @@ async def async_dyn_activity(args: Sequence[RawValue]) -> DynActivityValue: DynActivityValue("val2"), activity_name_override="some-activity-name", result_type_override=DynActivityValue, + shared_state_manager=shared_state_manager, ) assert result.result == DynActivityValue("some-activity-name - val1 - val2") -async def test_sync_activity_dynamic_thread(client: Client, worker: ExternalWorker): +async def test_sync_activity_dynamic_thread( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): with concurrent.futures.ThreadPoolExecutor( max_workers=default_max_concurrent_activities ) as executor: @@ -1303,11 +1510,16 @@ async def test_sync_activity_dynamic_thread(client: Client, worker: ExternalWork worker_config={"activity_executor": executor}, activity_name_override="some-activity-name", result_type_override=DynActivityValue, + shared_state_manager=shared_state_manager, ) assert result.result == DynActivityValue("some-activity-name - val1 - val2") -async def test_sync_activity_dynamic_process(client: Client, worker: ExternalWorker): +async def test_sync_activity_dynamic_process( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): with concurrent.futures.ProcessPoolExecutor() as executor: result = await _execute_workflow_with_activity( client, @@ -1318,11 +1530,16 @@ async def test_sync_activity_dynamic_process(client: Client, worker: ExternalWor worker_config={"activity_executor": executor}, activity_name_override="some-activity-name", result_type_override=DynActivityValue, + shared_state_manager=shared_state_manager, ) assert result.result == DynActivityValue("some-activity-name - val1 - val2") -async def test_activity_dynamic_duplicate(client: Client, worker: ExternalWorker): +async def test_activity_dynamic_duplicate( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn(dynamic=True) async def dyn_activity_1(args: Sequence[RawValue]) -> None: pass @@ -1333,7 +1550,11 @@ async def dyn_activity_2(args: Sequence[RawValue]) -> None: with pytest.raises(TypeError) as err: await _execute_workflow_with_activity( - client, worker, dyn_activity_1, additional_activities=[dyn_activity_2] + client, + worker, + dyn_activity_1, + additional_activities=[dyn_activity_2], + shared_state_manager=shared_state_manager, ) assert "More than one dynamic activity" in str(err.value) @@ -1350,6 +1571,7 @@ async def _execute_workflow_with_activity( worker: ExternalWorker, fn: Callable, *args: Any, + shared_state_manager: SharedStateManager, count: Optional[int] = None, index_as_arg: Optional[bool] = None, schedule_to_close_timeout_ms: Optional[int] = None, @@ -1369,7 +1591,7 @@ async def _execute_workflow_with_activity( worker_config["client"] = client worker_config["task_queue"] = str(uuid.uuid4()) worker_config["activities"] = [fn] + additional_activities - worker_config["shared_state_manager"] = _default_shared_state_manager + worker_config["shared_state_manager"] = shared_state_manager if not worker_config.get("max_concurrent_activities"): worker_config["max_concurrent_activities"] = default_max_concurrent_activities async with Worker(**worker_config): @@ -1441,7 +1663,9 @@ def emit(self, record: logging.LogRecord) -> None: async def test_activity_failure_trace_identifier( - client: Client, worker: ExternalWorker + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, ): @activity.defn async def raise_error(): @@ -1452,7 +1676,12 @@ async def raise_error(): try: with pytest.raises(WorkflowFailureError) as err: - await _execute_workflow_with_activity(client, worker, raise_error) + await _execute_workflow_with_activity( + client, + worker, + raise_error, + shared_state_manager=shared_state_manager, + ) assert ( str(assert_activity_application_error(err.value)) == "RuntimeError: oh no!" ) @@ -1462,7 +1691,11 @@ async def raise_error(): activity.logger.base_logger.removeHandler(CustomLogHandler()) -async def test_activity_heartbeat_context(client: Client, worker: ExternalWorker): +async def test_activity_heartbeat_context( + client: Client, + worker: ExternalWorker, + shared_state_manager: SharedStateManager, +): @activity.defn async def heartbeat(): if activity.info().attempt == 1: @@ -1484,13 +1717,20 @@ async def h(): return "details: " + activity.info().heartbeat_details[0] result = await _execute_workflow_with_activity( - client, worker, heartbeat, retry_max_attempts=2 + client, + worker, + heartbeat, + retry_max_attempts=2, + shared_state_manager=shared_state_manager, ) assert result.result == "details: Some detail" async def test_activity_reset_catch( - client: Client, worker: ExternalWorker, env: WorkflowEnvironment + client: Client, + worker: ExternalWorker, + env: WorkflowEnvironment, + shared_state_manager: SharedStateManager, ): if env.supports_time_skipping: pytest.skip("Time skipping server doesn't support activity reset") @@ -1541,6 +1781,7 @@ def sync_wait_cancel() -> str: client, worker, wait_cancel, + shared_state_manager=shared_state_manager, ) assert result.result == "Got cancelled error, reset? True" @@ -1552,12 +1793,16 @@ def sync_wait_cancel() -> str: worker, sync_wait_cancel, worker_config=config, + shared_state_manager=shared_state_manager, ) assert result.result == "Got cancelled error, reset? True" async def test_activity_reset_history( - client: Client, worker: ExternalWorker, env: WorkflowEnvironment + client: Client, + worker: ExternalWorker, + env: WorkflowEnvironment, + shared_state_manager: SharedStateManager, ): if env.supports_time_skipping: pytest.skip("Time skipping server doesn't support activity reset") @@ -1582,6 +1827,7 @@ async def wait_cancel() -> str: client, worker, wait_cancel, + shared_state_manager=shared_state_manager, ) assert isinstance(e.value.cause, ActivityError) assert isinstance(e.value.cause.cause, ApplicationError) From 359b6580e769c889a77dd38bf28a1b2addd807d1 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Tue, 25 Nov 2025 11:43:32 -0800 Subject: [PATCH 3/3] run formatter --- tests/conftest.py | 2 +- tests/worker/test_activity.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bf0f6856e..d45e85e14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,8 +9,8 @@ from temporalio.client import Client from temporalio.testing import WorkflowEnvironment -from tests.helpers.worker import ExternalPythonWorker, ExternalWorker from temporalio.worker import SharedStateManager +from tests.helpers.worker import ExternalPythonWorker, ExternalWorker from . import DEV_SERVER_DOWNLOAD_VERSION diff --git a/tests/worker/test_activity.py b/tests/worker/test_activity.py index ef040071a..7dfc70e74 100644 --- a/tests/worker/test_activity.py +++ b/tests/worker/test_activity.py @@ -18,9 +18,9 @@ import pytest -import temporalio.exceptions import temporalio.api.common.v1 import temporalio.api.workflowservice.v1 +import temporalio.exceptions from temporalio import activity, workflow from temporalio.client import ( AsyncActivityHandle, @@ -53,7 +53,6 @@ kitchen_sink_retry_policy, ) - default_max_concurrent_activities = 50