Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/workflows/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from workflows.types import RunResultT

from .serializers import BaseSerializer, JsonSerializer
from .state_store import InMemoryStateStore, MODEL_T, DictState
from .state_store import MODEL_T, DictState, InMemoryStateStore

if TYPE_CHECKING: # pragma: no cover
from workflows import Workflow
Expand Down Expand Up @@ -261,7 +261,9 @@ def from_dict(
msg = "Error creating a Context instance: the provided payload has a wrong or old format."
raise ContextSerdeError(msg) from e

async def set(self, key: str, value: Any, make_private: bool = False) -> None:
async def set(
self, key: str, value: Any, make_private: bool = False
) -> None: # pragma: no cover
"""
Store `value` into the Context under `key`.

Expand Down Expand Up @@ -518,9 +520,9 @@ async def wait_for_event(

# send the waiter event if it's not already sent
if waiter_event is not None:
is_waiting = await self.get(waiter_id, default=False)
is_waiting = await self.store.get(waiter_id, default=False)
if not is_waiting:
await self.set(waiter_id, True)
await self.store.set(waiter_id, True)
self.write_event_to_stream(waiter_event)

while True:
Expand All @@ -536,7 +538,7 @@ async def wait_for_event(
else:
continue
finally:
await self.set(waiter_id, False)
await self.store.set(waiter_id, False)

def write_event_to_stream(self, ev: Event | None) -> None:
self._streaming_queue.put_nowait(ev)
Expand Down
2 changes: 1 addition & 1 deletion src/workflows/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

try:
from typing import Union
except ImportError:
except ImportError: # pragma: no cover
from typing_extensions import Union

from .events import StopEvent
Expand Down
Empty file added tests/context/__init__.py
Empty file.
71 changes: 18 additions & 53 deletions tests/test_context.py → tests/context/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from workflows.context import Context
from workflows.context.state_store import DictState
from workflows.decorators import StepConfig, step
from workflows.errors import ContextSerdeError, WorkflowRuntimeError
from workflows.errors import WorkflowRuntimeError
from workflows.events import (
Event,
HumanResponseEvent,
Expand All @@ -30,7 +30,7 @@
)
from workflows.workflow import Workflow

from .conftest import AnotherTestEvent, OneTestEvent
from ..conftest import AnotherTestEvent, OneTestEvent


@pytest.mark.asyncio
Expand Down Expand Up @@ -64,26 +64,19 @@ async def step3(
@pytest.mark.asyncio
async def test_get_default(workflow: Workflow) -> None:
c1: Context[DictState] = Context(workflow)
assert await c1.get(key="test_key", default=42) == 42
assert await c1.store.get("test_key", default=42) == 42


@pytest.mark.asyncio
async def test_get(ctx: Context) -> None:
await ctx.set("foo", 42)
assert await ctx.get("foo") == 42
await ctx.store.set("foo", 42)
assert await ctx.store.get("foo") == 42


@pytest.mark.asyncio
async def test_get_not_found(ctx: Context) -> None:
with pytest.raises(ValueError):
await ctx.get("foo")


@pytest.mark.asyncio
async def test_legacy_data(workflow: Workflow) -> None:
c1: Context[DictState] = Context(workflow)
await c1.set(key="test_key", value=42)
assert await c1.get("test_key") == 42
await ctx.store.get("foo")


def test_send_event_step_is_none(ctx: Context) -> None:
Expand Down Expand Up @@ -149,20 +142,13 @@ def test_to_dict_with_events_buffer(ctx: Context) -> None:
assert json.dumps(ctx.to_dict())


@pytest.mark.asyncio
async def test_deprecated_params(ctx: Context) -> None:
with pytest.warns(
DeprecationWarning, match="`make_private` is deprecated and will be ignored"
):
await ctx.set("foo", 42, make_private=True)


@pytest.mark.asyncio
async def test_empty_inprogress_when_workflow_done(workflow: Workflow) -> None:
h = workflow.run()
_ = await h

# there shouldn't be any in progress events
assert h.ctx is not None
for inprogress_list in h.ctx._in_progress.values():
assert len(inprogress_list) == 0

Expand Down Expand Up @@ -220,7 +206,7 @@ async def step1(self, ctx: Context, ev: StartEvent) -> StopEvent:
@pytest.mark.asyncio
async def test_prompt_and_wait(ctx: Context) -> None:
prompt_id = "test_prompt_and_wait"
prompt_event = InputRequiredEvent(prefix="test_prompt_and_wait")
prompt_event = InputRequiredEvent(prefix="test_prompt_and_wait") # type: ignore
expected_event = HumanResponseEvent
requirements = {"waiter_id": "test_prompt_and_wait"}
timeout = 10
Expand All @@ -235,7 +221,7 @@ async def test_prompt_and_wait(ctx: Context) -> None:
)
)
await asyncio.sleep(0.01)
ctx.send_event(HumanResponseEvent(response="foo", waiter_id="test_prompt_and_wait"))
ctx.send_event(HumanResponseEvent(response="foo", waiter_id="test_prompt_and_wait")) # type: ignore

result = await waiting_task
assert result.response == "foo"
Expand Down Expand Up @@ -264,7 +250,7 @@ async def spawn_waiters(

@step
async def waiter_one(self, ctx: Context, ev: Waiter1) -> ResultEvent:
ctx.write_event_to_stream(InputRequiredEvent(prefix="waiter_one"))
ctx.write_event_to_stream(InputRequiredEvent(prefix="waiter_one")) # type: ignore

new_ev: HumanResponseEvent = await ctx.wait_for_event(
HumanResponseEvent,
Expand All @@ -274,7 +260,7 @@ async def waiter_one(self, ctx: Context, ev: Waiter1) -> ResultEvent:

@step
async def waiter_two(self, ctx: Context, ev: Waiter2) -> ResultEvent:
ctx.write_event_to_stream(InputRequiredEvent(prefix="waiter_two"))
ctx.write_event_to_stream(InputRequiredEvent(prefix="waiter_two")) # type: ignore

new_ev: HumanResponseEvent = await ctx.wait_for_event(
HumanResponseEvent,
Expand Down Expand Up @@ -302,11 +288,11 @@ async def test_wait_for_multiple_events_in_workflow() -> None:
async for ev in handler.stream_events():
if isinstance(ev, InputRequiredEvent) and ev.prefix == "waiter_one":
handler.ctx.send_event(
HumanResponseEvent(response="foo", waiter_id="waiter_one")
HumanResponseEvent(response="foo", waiter_id="waiter_one") # type: ignore
)
elif isinstance(ev, InputRequiredEvent) and ev.prefix == "waiter_two":
handler.ctx.send_event(
HumanResponseEvent(response="bar", waiter_id="waiter_two")
HumanResponseEvent(response="bar", waiter_id="waiter_two") # type: ignore
)

result = await handler
Expand All @@ -321,11 +307,11 @@ async def test_wait_for_multiple_events_in_workflow() -> None:
async for ev in handler.stream_events():
if isinstance(ev, InputRequiredEvent) and ev.prefix == "waiter_one":
handler.ctx.send_event(
HumanResponseEvent(response="fizz", waiter_id="waiter_one")
HumanResponseEvent(response="fizz", waiter_id="waiter_one") # type: ignore
)
elif isinstance(ev, InputRequiredEvent) and ev.prefix == "waiter_two":
handler.ctx.send_event(
HumanResponseEvent(response="buzz", waiter_id="waiter_two")
HumanResponseEvent(response="buzz", waiter_id="waiter_two") # type: ignore
)

result = await handler
Expand All @@ -339,28 +325,7 @@ def test_get_holding_events(ctx: Context) -> None:

@pytest.mark.asyncio
async def test_clear(ctx: Context) -> None:
await ctx.set("test_key", 42)
ctx.clear()
res = await ctx.get("test_key", default=None)
await ctx.store.set("test_key", 42)
await ctx.store.clear()
res = await ctx.store.get("test_key", default=None)
assert res is None


def test_serialization_roundtrip(ctx: Context, workflow: Workflow) -> None:
assert Context.from_dict(workflow, ctx.to_dict())


def test_old_serialization(ctx: Context, workflow: Workflow) -> None:
old_payload = {
"globals": {},
"streaming_queue": "[]",
"queues": {"test_id": "[]"},
"stepwise": False,
"events_buffer": {},
"in_progress": {},
"accepted_events": [],
"broker_log": [],
"waiter_id": "test_id",
"is_running": False,
}
with pytest.raises(ContextSerdeError):
Context.from_dict(workflow, old_payload)
31 changes: 31 additions & 0 deletions tests/context/test_serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 LlamaIndex Inc.

from __future__ import annotations

import pytest

from workflows.context import Context
from workflows.errors import ContextSerdeError
from workflows.workflow import Workflow


def test_serialization_roundtrip(ctx: Context, workflow: Workflow) -> None:
assert Context.from_dict(workflow, ctx.to_dict())


def test_old_serialization(ctx: Context, workflow: Workflow) -> None:
old_payload = {
"globals": {},
"streaming_queue": "[]",
"queues": {"test_id": "[]"},
"stepwise": False,
"events_buffer": {},
"in_progress": {},
"accepted_events": [],
"broker_log": [],
"waiter_id": "test_id",
"is_running": False,
}
with pytest.raises(ContextSerdeError):
Context.from_dict(workflow, old_payload)
38 changes: 38 additions & 0 deletions tests/context/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest

from workflows.context.utils import (
get_qualified_name,
import_module_from_qualified_name,
)


def test_get_qualified_name() -> None:
with pytest.raises(
AttributeError,
match="Object foo does not have required attributes: 'str' object has no attribute '__module__'",
):
get_qualified_name("foo")


def test_import_module_from_qualified_name_wrong_name() -> None:
with pytest.raises(
ValueError, match="Qualified name must be in format 'module.attribute'"
):
import_module_from_qualified_name("not containing a dot")
import_module_from_qualified_name("")


def test_import_module_from_qualified_name_wrong_package() -> None:
with pytest.raises(
ImportError,
match="Failed to import module __doesnt: No module named '__doesnt'",
):
import_module_from_qualified_name("__doesnt.exist")


def test_import_module_from_qualified_name_wrong_module() -> None:
with pytest.raises(
AttributeError,
match="Attribute doesntexist not found in module typing: module 'typing' has no attribute 'doesntexist'",
):
import_module_from_qualified_name("typing.doesntexist")
9 changes: 8 additions & 1 deletion tests/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from workflows.decorators import step
from workflows.events import Event, StartEvent, StopEvent
from workflows.resource import Resource
from workflows.resource import Resource, ResourceManager
from workflows.workflow import Workflow


Expand Down Expand Up @@ -180,3 +180,10 @@ async def test_step_2(
await wf_1.run()
assert cc1 == 1 # type: ignore
assert cc2 == 1 # type: ignore


@pytest.mark.asyncio
async def test_resource_manager() -> None:
m = ResourceManager()
await m.set("test_resource", 42)
assert m.get_all() == {"test_resource": 42}
6 changes: 3 additions & 3 deletions tests/test_retry_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ class DummyWorkflow(Workflow):
# Set a small delay to avoid impacting the CI speed too much
@step(retry_policy=ConstantDelayRetryPolicy(delay=0.2))
async def flaky_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
count = await ctx.get("counter", default=0)
count = await ctx.store.get("counter", default=0)
ctx.send_event(CountEvent())
if count < 3:
raise ValueError("Something bad happened!")
return StopEvent(result="All good!")

@step
async def counter(self, ctx: Context, ev: CountEvent) -> None:
count = await ctx.get("counter", default=0)
await ctx.set("counter", count + 1)
count = await ctx.store.get("counter", default=0)
await ctx.store.set("counter", count + 1)

workflow = DummyWorkflow(disable_validation=True)
assert await workflow.run() == "All good!"
Expand Down
6 changes: 3 additions & 3 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ class CounterWorkflow(Workflow):
async def count(self, ctx: Context, ev: StartEvent) -> StopEvent:
ctx.write_event_to_stream(Event(msg="hello!"))

cur_count = await ctx.get("cur_count", default=0)
await ctx.set("cur_count", cur_count + 1)
cur_count = await ctx.store.get("cur_count", default=0)
await ctx.store.set("cur_count", cur_count + 1)
return StopEvent(result="done")

wf = CounterWorkflow()
Expand All @@ -158,4 +158,4 @@ async def count(self, ctx: Context, ev: StartEvent) -> StopEvent:
await handler_2

assert handler_2.ctx
assert await handler_2.ctx.get("cur_count") == 2
assert await handler_2.ctx.store.get("cur_count") == 2
5 changes: 5 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,8 @@ def test_is_free_function() -> None:
assert is_free_function("some_function.<locals>.MyClass.my_function") is False
with pytest.raises(ValueError):
is_free_function("")


def test_inspect_signature_raises_if_not_callable() -> None:
with pytest.raises(TypeError, match="Expected a callable object, got str"):
inspect_signature("foo") # type: ignore
Loading
Loading