Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move dg_stack into a ContextVar to support with blocks in separate threads #7715

Merged
merged 3 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 7 additions & 9 deletions lib/streamlit/delta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
from streamlit.proto.RootContainer_pb2 import RootContainer
from streamlit.runtime import caching, legacy_caching
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit.runtime.scriptrunner.script_run_context import dg_stack
from streamlit.runtime.state import NoValue

if TYPE_CHECKING:
Expand Down Expand Up @@ -282,9 +283,7 @@ def __repr__(self) -> str:

def __enter__(self) -> None:
# with block started
ctx = get_script_run_ctx()
if ctx:
ctx.dg_stack.append(self)
dg_stack.set(dg_stack.get() + (self,))

def __exit__(
self,
Expand All @@ -293,9 +292,8 @@ def __exit__(
traceback: Any,
) -> Literal[False]:
# with block ended
ctx = get_script_run_ctx()
if ctx is not None:
ctx.dg_stack.pop()

dg_stack.set(dg_stack.get()[:-1])

# Re-raise any exceptions
return False
Expand All @@ -310,9 +308,9 @@ def _active_dg(self) -> DeltaGenerator:
if self == self._main_dg:
# We're being invoked via an `st.foo` pattern - use the current
# `with` dg (aka the top of the stack).
ctx = get_script_run_ctx()
if ctx and len(ctx.dg_stack) > 0:
return ctx.dg_stack[-1]
current_stack = dg_stack.get()
if len(current_stack) > 0:
return current_stack[-1]

# We're being invoked via an `st.sidebar.foo` pattern - ignore the
# current `with` dg.
Expand Down
7 changes: 2 additions & 5 deletions lib/streamlit/elements/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from streamlit.proto import Block_pb2
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner import ScriptRunContext, get_script_run_ctx
from streamlit.runtime.scriptrunner.script_run_context import dg_stack
from streamlit.runtime.state import WidgetArgs, WidgetCallback, WidgetKwargs

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,11 +53,7 @@ def _current_form(this_dg: DeltaGenerator) -> FormData | None:
if this_dg == this_dg._main_dg:
# We were created via an `st.foo` call.
# Walk up the dg_stack to see if we're nested inside a `with st.form` statement.
ctx = get_script_run_ctx()
if ctx is None or len(ctx.dg_stack) == 0:
return None

for dg in reversed(ctx.dg_stack):
for dg in reversed(dg_stack.get()):
if dg._form_data is not None:
return dg._form_data
else:
Expand Down
18 changes: 13 additions & 5 deletions lib/streamlit/runtime/scriptrunner/script_run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

import collections
import contextvars
import threading
from dataclasses import dataclass, field
from typing import Callable, Counter, Dict, List, Optional, Set
from typing import Callable, Counter, Dict, List, Optional, Set, Tuple

from typing_extensions import Final, TypeAlias

Expand All @@ -33,14 +34,24 @@
UserInfo: TypeAlias = Dict[str, Optional[str]]


# The dg_stack tracks the currently active DeltaGenerator, and is pushed to when
# a DeltaGenerator is entered via a `with` block. This is implemented as a ContextVar
# so that different threads or async tasks can have their own stacks.
dg_stack: contextvars.ContextVar[
Tuple["streamlit.delta_generator.DeltaGenerator", ...]
] = contextvars.ContextVar("dg_stack", default=tuple())


@dataclass
class ScriptRunContext:
"""A context object that contains data for a "script run" - that is,
data that's scoped to a single ScriptRunner execution (and therefore also
scoped to a single connected "session").

ScriptRunContext is used internally by virtually every `st.foo()` function.
It is accessed only from the script thread that's created by ScriptRunner.
It is accessed only from the script thread that's created by ScriptRunner,
or from app-created helper threads that have been "attached" to the
ScriptRunContext via `add_script_run_ctx`.

Streamlit code typically retrieves the active ScriptRunContext via the
`get_script_run_ctx` function.
Expand All @@ -64,9 +75,6 @@ class ScriptRunContext:
widget_user_keys_this_run: Set[str] = field(default_factory=set)
form_ids_this_run: Set[str] = field(default_factory=set)
cursors: Dict[int, "streamlit.cursor.RunningCursor"] = field(default_factory=dict)
dg_stack: List["streamlit.delta_generator.DeltaGenerator"] = field(
default_factory=list
)
script_requests: Optional[ScriptRequests] = None

def reset(self, query_string: str = "", page_script_hash: str = "") -> None:
Expand Down
118 changes: 118 additions & 0 deletions lib/tests/streamlit/delta_generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

"""DeltaGenerator Unittest."""

import asyncio
import functools
import inspect
import json
import logging
import re
import threading
import unittest
from unittest.mock import MagicMock, patch

Expand All @@ -38,6 +40,7 @@
from streamlit.proto.Text_pb2 import Text as TextProto
from streamlit.proto.TextArea_pb2 import TextArea
from streamlit.proto.TextInput_pb2 import TextInput
from streamlit.runtime.scriptrunner import add_script_run_ctx
from streamlit.runtime.state.common import compute_widget_id
from streamlit.runtime.state.widgets import _build_duplicate_widget_message
from tests.delta_generator_test_case import DeltaGeneratorTestCase
Expand Down Expand Up @@ -476,6 +479,121 @@ def test_nested_with(self):
msg.metadata.delta_path,
)

def test_threads_with(self):
"""
Tests that with statements work correctly when multiple threads are involved.

The test sequence is as follows:

Main Thread | Worker Thread
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent use of diagram explanation

-----------------------------------------------------
with container1: |
| with container2:
st.markdown("Object 1") |
| st.markdown("Object 2")


We check that Object1 is created in container1 and object2 is created in container2.
"""
container1 = st.container()
container2 = st.container()

with_1 = threading.Event()
with_2 = threading.Event()
object_1 = threading.Event()

def thread():
with_1.wait()
with container2:
with_2.set()
object_1.wait()

st.markdown("Object 2")
msg = self.get_message_from_queue()
self.assertEqual(
make_delta_path(RootContainer.MAIN, (1,), 0),
msg.metadata.delta_path,
)

worker_thread = threading.Thread(target=thread)
add_script_run_ctx(worker_thread)
worker_thread.start()

with container1:
with_1.set()
with_2.wait()

st.markdown("Object in container 1")
msg = self.get_message_from_queue()
self.assertEqual(
make_delta_path(RootContainer.MAIN, (0,), 0),
msg.metadata.delta_path,
)

object_1.set()
worker_thread.join()

def test_asyncio_with(self):
"""
Tests that with statements work correctly when multiple async tasks are involved.

The test sequence is as follows:

Task 1 | Task 2
-----------------------------------------------------
with container1:
asyncio.create_task() ->
| st.markdown("Object 1a")
| with container2:
st.markdown("Object 1b") |
| st.markdown("Object 2")

In this scenario, Task 2 should inherit the container1 context from Task 1 when it is created, so Objects 1a and 1b
will both go in container 1, and object 2 will go in container 2.
"""
container1 = st.container()
container2 = st.container()

with_2 = asyncio.Event()
object_1 = asyncio.Event()

async def task1():
with container1:
task = asyncio.create_task(task2())

await with_2.wait()

st.markdown("Object 1b")
msg = self.get_message_from_queue()
self.assertEqual(
make_delta_path(RootContainer.MAIN, (0,), 1),
msg.metadata.delta_path,
)

object_1.set()
await task

async def task2():
st.markdown("Object 1a")
msg = self.get_message_from_queue()
self.assertEqual(
make_delta_path(RootContainer.MAIN, (0,), 0),
msg.metadata.delta_path,
)

with container2:
with_2.set()
st.markdown("Object 2")
msg = self.get_message_from_queue()
self.assertEqual(
make_delta_path(RootContainer.MAIN, (1,), 0),
msg.metadata.delta_path,
)

await object_1.wait()

asyncio.get_event_loop().run_until_complete(task1())


class DeltaGeneratorWriteTest(DeltaGeneratorTestCase):
"""Test DeltaGenerator Text, Alert, Json, and Markdown Classes."""
Expand Down