diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index a62195469c2..7e6e9851568 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -85,7 +85,7 @@ def get_lifespan_tasks(self) -> tuple[asyncio.Task | Callable, ...]: return tuple(self._lifespan_tasks) @contextlib.asynccontextmanager - async def _run_lifespan_tasks(self, app: Starlette): + async def _run_lifespan_tasks(self, starlette_app: Starlette): self._lifespan_tasks_started = True running_tasks = [] try: @@ -98,7 +98,11 @@ async def _run_lifespan_tasks(self, app: Starlette): else: signature = inspect.signature(task) if "app" in signature.parameters: - task = functools.partial(task, app=app) + task = functools.partial(task, app=self) + if "starlette_app" in signature.parameters: + task = functools.partial( + task, starlette_app=starlette_app + ) t_ = task() if isinstance(t_, contextlib._AsyncGeneratorContextManager): await stack.enter_async_context(t_) diff --git a/tests/units/app_mixins/test_lifespan.py b/tests/units/app_mixins/test_lifespan.py new file mode 100644 index 00000000000..5da2be7a015 --- /dev/null +++ b/tests/units/app_mixins/test_lifespan.py @@ -0,0 +1,36 @@ +"""Unit tests for lifespan mixin behavior.""" + +from __future__ import annotations + +import asyncio + +import pytest +from starlette.applications import Starlette + +from reflex.app_mixins.lifespan import LifespanMixin + + +@pytest.mark.asyncio +async def test_lifespan_task_app_param_receives_reflex_app_instance(): + """Lifespan tasks should receive the Reflex app instance, not Starlette.""" + + class DummyApp(LifespanMixin): + """Minimal test app based on the lifespan mixin.""" + + app = DummyApp() + received: dict[str, object] = {} + + def lifespan_task(app): + """Record the app argument injected by the lifespan runner. + + Args: + app: App object injected by the lifespan runner. + """ + received["app"] = app + + app.register_lifespan_task(lifespan_task) + + async with app._run_lifespan_tasks(Starlette()): + await asyncio.sleep(0) + + assert received["app"] is app