Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests of asyncio.Lock and asyncio.Semaphore usage #567

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions tests/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from contextlib import closing
from datetime import timedelta
from typing import Awaitable, Callable, Optional, Sequence, Type, TypeVar
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar

from temporalio.api.common.v1 import WorkflowExecution
from temporalio.api.enums.v1 import IndexedValueType
Expand All @@ -14,11 +14,12 @@
)
from temporalio.api.update.v1 import UpdateRef
from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest
from temporalio.client import BuildIdOpAddNewDefault, Client
from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle
from temporalio.common import SearchAttributeKey
from temporalio.service import RPCError, RPCStatusCode
from temporalio.worker import Worker, WorkflowRunner
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
from temporalio.workflow import UpdateMethodMultiParam


def new_worker(
Expand Down Expand Up @@ -128,3 +129,24 @@ async def workflow_update_exists(
if err.status != RPCStatusCode.NOT_FOUND:
raise
return False


# TODO: type update return value
async def admitted_update_task(
    client: Client,
    handle: WorkflowHandle,
    update_method: UpdateMethodMultiParam,
    id: str,
    **kwargs,
) -> asyncio.Task:
    """
    Return an asyncio.Task for an update after waiting for it to be admitted.

    Args:
        client: Client used to poll for the existence of the update.
        handle: Handle of the workflow the update is sent to.
        update_method: Update handler to execute.
        id: Update ID; passed to ``execute_update`` and used when polling.
        **kwargs: Forwarded unchanged to ``handle.execute_update``.

    Returns:
        The task wrapping ``handle.execute_update``; await it to obtain the
        update result.

    Raises:
        Whatever ``assert_eq_eventually`` raises when the update is never seen.
        In that case the background task is cancelled before re-raising.
    """
    update_task = asyncio.create_task(
        handle.execute_update(update_method, id=id, **kwargs)
    )
    try:
        await assert_eq_eventually(
            True,
            lambda: workflow_update_exists(client, handle.id, id),
        )
    except BaseException:
        # Do not leave the background task pending and orphaned when admission
        # is never observed (or when this coroutine itself is cancelled).
        update_task.cancel()
        raise
    return update_task
300 changes: 300 additions & 0 deletions tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
WorkflowRunner,
)
from tests.helpers import (
admitted_update_task,
assert_eq_eventually,
ensure_search_attributes_present,
find_free_port,
Expand Down Expand Up @@ -5505,3 +5506,302 @@ def _unfinished_handler_warning_cls(self) -> Type:
"update": workflow.UnfinishedUpdateHandlersWarning,
"signal": workflow.UnfinishedSignalHandlersWarning,
}[self.handler_type]


# The following Lock and Semaphore tests test that asyncio concurrency primitives work as expected
# in workflow code. There is nothing Temporal-specific about the way that asyncio.Lock and
# asyncio.Semaphore are used here.


@activity.defn
async def noop_activity_for_lock_or_semaphore_tests() -> None:
    """Do nothing; exists so workflow code has an activity to await."""


@dataclass
class LockOrSemaphoreWorkflowConcurrencySummary:
    """Concurrency statistics returned by the lock/semaphore test workflows."""

    # Number of distinct IDs that entered the critical section at any point.
    ever_in_critical_section: int
    # Largest number of IDs that were inside the critical section simultaneously.
    peak_in_critical_section: int


@dataclass
class UseLockOrSemaphoreWorkflowParameters:
    """Inputs that configure the lock/semaphore test workflows."""

    # Number of coroutines the workflow run method launches concurrently.
    n_coroutines: int = 0
    # Initial value for an asyncio.Semaphore; None selects an asyncio.Lock instead.
    semaphore_initial_value: Optional[int] = None
    # Optional seconds to sleep inside the critical section; without it a no-op
    # activity is executed there instead.
    sleep: Optional[float] = None
    # Optional seconds to wait for acquisition before giving up.
    timeout: Optional[float] = None


@workflow.defn
class CoroutinesUseLockOrSemaphoreWorkflow:
    """Workflow whose coroutines contend for an asyncio.Lock or asyncio.Semaphore.

    Each coroutine tries to acquire the primitive, spends time in a critical
    section, and releases it. The workflow records which IDs entered the
    critical section and the peak number present at once.
    """

    def __init__(self) -> None:
        # These two are only declared here; init() assigns them once the
        # parameters are known.
        self.params: UseLockOrSemaphoreWorkflowParameters
        self.lock_or_semaphore: Union[asyncio.Lock, asyncio.Semaphore]
        self._currently_in_critical_section: set[str] = set()
        self._ever_in_critical_section: set[str] = set()
        self._peak_in_critical_section = 0

    def init(self, params: UseLockOrSemaphoreWorkflowParameters):
        """Store ``params`` and create the Lock or Semaphore they call for."""
        self.params = params
        if self.params.semaphore_initial_value is not None:
            self.lock_or_semaphore = asyncio.Semaphore(
                self.params.semaphore_initial_value
            )
        else:
            self.lock_or_semaphore = asyncio.Lock()

    @workflow.run
    async def run(
        self,
        params: Optional[UseLockOrSemaphoreWorkflowParameters],
    ) -> LockOrSemaphoreWorkflowConcurrencySummary:
        """Run ``n_coroutines`` contending coroutines and return the summary."""
        # TODO: Use workflow init method when it exists.
        assert params
        self.init(params)
        await asyncio.gather(
            *(self.coroutine(f"{i}") for i in range(self.params.n_coroutines))
        )
        # Check the set itself for emptiness: any() would test the truthiness of
        # the IDs and therefore overlook an empty-string ID left behind.
        assert not self._currently_in_critical_section
        return LockOrSemaphoreWorkflowConcurrencySummary(
            len(self._ever_in_critical_section),
            self._peak_in_critical_section,
        )

    async def coroutine(self, id: str):
        """Acquire the primitive, occupy the critical section, then release.

        Returns without entering the critical section if a timeout is
        configured and acquisition does not complete within it.
        """
        # Compare against None (as init() does) so that an explicit 0 is honored
        # instead of being treated as "no timeout".
        if self.params.timeout is not None:
            try:
                await asyncio.wait_for(
                    self.lock_or_semaphore.acquire(), self.params.timeout
                )
            except asyncio.TimeoutError:
                return
        else:
            await self.lock_or_semaphore.acquire()
        self._enters_critical_section(id)
        try:
            if self.params.sleep is not None:
                await asyncio.sleep(self.params.sleep)
            else:
                await workflow.execute_activity(
                    noop_activity_for_lock_or_semaphore_tests,
                    schedule_to_close_timeout=timedelta(seconds=30),
                )
        finally:
            self.lock_or_semaphore.release()
            self._exits_critical_section(id)

    def _enters_critical_section(self, id: str) -> None:
        """Record that ``id`` is now inside the critical section."""
        self._currently_in_critical_section.add(id)
        self._ever_in_critical_section.add(id)
        self._peak_in_critical_section = max(
            self._peak_in_critical_section,
            len(self._currently_in_critical_section),
        )

    def _exits_critical_section(self, id: str) -> None:
        """Record that ``id`` has left the critical section."""
        self._currently_in_critical_section.remove(id)


@workflow.defn
class HandlerCoroutinesUseLockOrSemaphoreWorkflow(CoroutinesUseLockOrSemaphoreWorkflow):
    """Variant in which update handler executions are the contending coroutines.

    The run method only waits for the ``finish`` signal; each ``my_update``
    execution runs the inherited ``coroutine`` under its update ID.
    """

    def __init__(self) -> None:
        super().__init__()
        self.workflow_may_exit = False

    @workflow.run
    async def run(
        self,
        _: Optional[UseLockOrSemaphoreWorkflowParameters] = None,
    ) -> LockOrSemaphoreWorkflowConcurrencySummary:
        """Wait until ``finish`` has been signalled, then return the summary."""
        await workflow.wait_condition(lambda: self.workflow_may_exit)
        return LockOrSemaphoreWorkflowConcurrencySummary(
            len(self._ever_in_critical_section),
            self._peak_in_critical_section,
        )

    @workflow.update
    async def my_update(self, params: UseLockOrSemaphoreWorkflowParameters):
        """Run one contending coroutine identified by the current update's ID."""
        # TODO: Use workflow init method when it exists.
        if not hasattr(self, "params"):
            self.init(params)
        # Bind the name outside the assert so it still exists when assertions
        # are stripped (python -O).
        update_info = workflow.current_update_info()
        assert update_info
        await self.coroutine(update_info.id)

    @workflow.signal
    async def finish(self):
        """Allow the run method to return."""
        self.workflow_may_exit = True


async def _do_workflow_coroutines_lock_or_semaphore_test(
    client: Client,
    params: UseLockOrSemaphoreWorkflowParameters,
    expectation: LockOrSemaphoreWorkflowConcurrencySummary,
):
    """Execute CoroutinesUseLockOrSemaphoreWorkflow and check its summary.

    Starts a worker for the workflow and the no-op activity, runs the workflow
    to completion with ``params``, and asserts the result equals ``expectation``.
    """
    workflow_cls = CoroutinesUseLockOrSemaphoreWorkflow
    async with new_worker(
        client,
        workflow_cls,
        activities=[noop_activity_for_lock_or_semaphore_tests],
    ) as worker:
        actual = await client.execute_workflow(
            workflow_cls.run,
            arg=params,
            id=str(uuid.uuid4()),
            task_queue=worker.task_queue,
        )
        assert actual == expectation


async def _do_update_handler_lock_or_semaphore_test(
    client: Client,
    env: WorkflowEnvironment,
    params: UseLockOrSemaphoreWorkflowParameters,
    n_updates: int,
    expectation: LockOrSemaphoreWorkflowConcurrencySummary,
):
    """Send ``n_updates`` updates to the handler workflow and check its summary.

    The updates are created before any worker polls, so that all of them are
    waiting when the worker starts. After every update has completed, the
    workflow is signalled to finish and its result is compared with
    ``expectation``. Skipped when the environment supports time skipping.
    """
    if env.supports_time_skipping:
        pytest.skip(
            "Java test server: https://github.com/temporalio/sdk-java/issues/1903"
        )

    # Use a task queue unique to this invocation so the worker started below can
    # only receive tasks belonging to the workflow started here.
    task_queue = f"tq-{uuid.uuid4()}"
    handle = await client.start_workflow(
        HandlerCoroutinesUseLockOrSemaphoreWorkflow.run,
        id=f"wf-{uuid.uuid4()}",
        task_queue=task_queue,
    )
    # Create updates in Admitted state, before the worker starts polling.
    admitted_updates = [
        await admitted_update_task(
            client,
            handle,
            HandlerCoroutinesUseLockOrSemaphoreWorkflow.my_update,
            arg=params,
            id=f"update-{i}",
        )
        for i in range(n_updates)
    ]
    async with new_worker(
        client,
        HandlerCoroutinesUseLockOrSemaphoreWorkflow,
        activities=[noop_activity_for_lock_or_semaphore_tests],
        task_queue=task_queue,
    ):
        for update_task in admitted_updates:
            await update_task
        await handle.signal(HandlerCoroutinesUseLockOrSemaphoreWorkflow.finish)
        summary = await handle.result()
        assert summary == expectation


async def test_workflow_coroutines_can_use_lock(client: Client):
    """Five workflow coroutines share a lock, which limits concurrency to 1."""
    params = UseLockOrSemaphoreWorkflowParameters(n_coroutines=5)
    expected = LockOrSemaphoreWorkflowConcurrencySummary(
        ever_in_critical_section=5, peak_in_critical_section=1
    )
    await _do_workflow_coroutines_lock_or_semaphore_test(
        client, params, expectation=expected
    )


async def test_update_handler_can_use_lock_to_serialize_handler_executions(
    client: Client, env: WorkflowEnvironment
):
    """Five update handler executions share a lock, which limits concurrency to 1."""
    params = UseLockOrSemaphoreWorkflowParameters()
    expected = LockOrSemaphoreWorkflowConcurrencySummary(
        ever_in_critical_section=5, peak_in_critical_section=1
    )
    await _do_update_handler_lock_or_semaphore_test(
        client, env, params, n_updates=5, expectation=expected
    )


async def test_workflow_coroutines_lock_acquisition_respects_timeout(client: Client):
    """Second and subsequent coroutines fail to acquire the lock due to the timeout."""
    params = UseLockOrSemaphoreWorkflowParameters(
        n_coroutines=5, sleep=0.5, timeout=0.1
    )
    expected = LockOrSemaphoreWorkflowConcurrencySummary(
        ever_in_critical_section=1, peak_in_critical_section=1
    )
    await _do_workflow_coroutines_lock_or_semaphore_test(
        client, params, expectation=expected
    )


async def test_update_handler_lock_acquisition_respects_timeout(
    client: Client, env: WorkflowEnvironment
):
    """Second and subsequent handler executions fail to acquire the lock due to the timeout."""
    params = UseLockOrSemaphoreWorkflowParameters(sleep=0.5, timeout=0.1)
    expected = LockOrSemaphoreWorkflowConcurrencySummary(
        ever_in_critical_section=1, peak_in_critical_section=1
    )
    await _do_update_handler_lock_or_semaphore_test(
        client, env, params, n_updates=5, expectation=expected
    )


async def test_workflow_coroutines_can_use_semaphore(client: Client):
    """Five workflow coroutines share a semaphore, which limits concurrency to 3."""
    params = UseLockOrSemaphoreWorkflowParameters(
        n_coroutines=5, semaphore_initial_value=3
    )
    expected = LockOrSemaphoreWorkflowConcurrencySummary(
        ever_in_critical_section=5, peak_in_critical_section=3
    )
    await _do_workflow_coroutines_lock_or_semaphore_test(
        client, params, expectation=expected
    )


async def test_update_handler_can_use_semaphore_to_control_handler_execution_concurrency(
    client: Client, env: WorkflowEnvironment
):
    """Five update handler executions share a semaphore, which limits concurrency to 3."""
    params = UseLockOrSemaphoreWorkflowParameters(semaphore_initial_value=3)
    expected = LockOrSemaphoreWorkflowConcurrencySummary(
        ever_in_critical_section=5, peak_in_critical_section=3
    )
    await _do_update_handler_lock_or_semaphore_test(
        client, env, params, n_updates=5, expectation=expected
    )


async def test_workflow_coroutine_semaphore_acquisition_respects_timeout(
    client: Client,
):
    """Initial entry to the semaphore succeeds, but all subsequent attempts to
    acquire a semaphore slot fail.
    """
    params = UseLockOrSemaphoreWorkflowParameters(
        n_coroutines=5, semaphore_initial_value=3, sleep=0.5, timeout=0.1
    )
    expected = LockOrSemaphoreWorkflowConcurrencySummary(
        ever_in_critical_section=3, peak_in_critical_section=3
    )
    await _do_workflow_coroutines_lock_or_semaphore_test(
        client, params, expectation=expected
    )


async def test_update_handler_semaphore_acquisition_respects_timeout(
    client: Client, env: WorkflowEnvironment
):
    """Initial entry to the semaphore succeeds, but all subsequent attempts to
    acquire a semaphore slot fail.
    """
    params = UseLockOrSemaphoreWorkflowParameters(
        semaphore_initial_value=3,
        sleep=0.5,
        timeout=0.1,
    )
    expected = LockOrSemaphoreWorkflowConcurrencySummary(
        ever_in_critical_section=3, peak_in_critical_section=3
    )
    await _do_update_handler_lock_or_semaphore_test(
        client, env, params, n_updates=5, expectation=expected
    )
Loading