Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing on_change_callback for CustomComponents #8633

Merged
merged 6 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions e2e_playwright/custom_components/popular_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,22 @@ def use_folium():
def use_option_menu():
from streamlit_option_menu import option_menu

key = "my_option_menu"

# TODO: uncomment the on_change callback as soon as streamlit-option-menu is updated and uses the new on_change callback
# def on_change():
# selection = st.session_state[key]
# st.write(f"Selection changed to {selection}")

with st.sidebar:
selected = option_menu(
"Main Menu",
["Home", "Settings"],
icons=["house", "gear"],
menu_icon="cast",
default_index=1,
key=key,
# on_change=on_change,
)
st.write(selected)

Expand Down
7 changes: 7 additions & 0 deletions e2e_playwright/custom_components/popular_components_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ def test_option_menu(app: Page):
_expect_no_exception(app)
_expect_iframe_attached(app)

# TODO: uncomment the on_change callback as soon as streamlit-option-menu is updated and uses the new on_change callback
# frame_locator = app.frame_locator("iframe")
# frame_locator.locator("a", has_text="Home").click()
# expect(
# app.get_by_test_id("stMarkdown").filter(has_text="Selection changed to Home")
# ).to_be_visible()


def test_url_fragment(app: Page):
"""Test that the url-fragment component renders"""
Expand Down
17 changes: 15 additions & 2 deletions lib/streamlit/components/types/base_custom_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

import os
from abc import ABC, abstractmethod
from typing import Any
from typing import TYPE_CHECKING, Any

from streamlit import util
from streamlit.errors import StreamlitAPIException

if TYPE_CHECKING:
from streamlit.runtime.state.common import WidgetCallback


class MarshallComponentException(StreamlitAPIException):
"""Class for exceptions generated during custom component marshalling."""
Expand Down Expand Up @@ -56,10 +59,17 @@ def __call__(
*args,
default: Any = None,
key: str | None = None,
on_change: WidgetCallback | None = None,
**kwargs,
) -> Any:
"""An alias for create_instance."""
return self.create_instance(*args, default=default, key=key, **kwargs)
return self.create_instance(
*args,
default=default,
key=key,
on_change=on_change,
**kwargs,
)

@property
def abspath(self) -> str | None:
Expand Down Expand Up @@ -102,6 +112,7 @@ def create_instance(
*args,
default: Any = None,
key: str | None = None,
on_change: WidgetCallback | None = None,
**kwargs,
) -> Any:
"""Create a new instance of the component.
Expand All @@ -118,6 +129,8 @@ def create_instance(
key: str or None
If not None, this is the user key we use to generate the
component's "widget ID".
on_change: WidgetCallback or None
An optional callback invoked when the widget's value changes. No arguments are passed to it.
**kwargs
Keyword args to pass to the component.

Expand Down
14 changes: 13 additions & 1 deletion lib/streamlit/components/v1/custom_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
from streamlit.runtime.state.common import WidgetCallback


class MarshallComponentException(StreamlitAPIException):
Expand All @@ -49,17 +50,25 @@ def __call__(
*args,
default: Any = None,
key: str | None = None,
on_change: WidgetCallback | None = None,
**kwargs,
) -> Any:
"""An alias for create_instance."""
return self.create_instance(*args, default=default, key=key, **kwargs)
return self.create_instance(
*args,
default=default,
key=key,
on_change=on_change,
**kwargs,
)

@gather_metrics("create_instance")
def create_instance(
self,
*args,
default: Any = None,
key: str | None = None,
on_change: WidgetCallback | None = None,
**kwargs,
) -> Any:
"""Create a new instance of the component.
Expand All @@ -76,6 +85,8 @@ def create_instance(
key: str or None
If not None, this is the user key we use to generate the
component's "widget ID".
on_change: WidgetCallback or None
An optional callback invoked when the widget's value changes. No arguments are passed to it.
**kwargs
Keyword args to pass to the component.

Expand Down Expand Up @@ -195,6 +206,7 @@ def deserialize_component(ui_value, widget_id=""):
deserializer=deserialize_component,
serializer=lambda x: x,
ctx=ctx,
on_change_handler=on_change,
)
widget_value = component_state.value

Expand Down
2 changes: 1 addition & 1 deletion lib/streamlit/runtime/state/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def user_key_from_widget_id(widget_id: str) -> str | None:
"None" as a key, but we can't avoid this kind of problem while storing the
string representation of the no-user-key sentinel as part of the widget id.
"""
user_key = widget_id.split("-", maxsplit=2)[-1]
user_key: str | None = widget_id.split("-", maxsplit=2)[-1]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an unrelated drive-by change since the linter complained about wrong typing

user_key = None if user_key == "None" else user_key
return user_key

Expand Down
38 changes: 38 additions & 0 deletions lib/tests/streamlit/components_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import inspect
import json
import os
import threading
import unittest
from typing import Any
from unittest import mock
Expand All @@ -38,9 +39,11 @@
from streamlit.components.v1.custom_component import CustomComponent
from streamlit.errors import DuplicateWidgetID, StreamlitAPIException
from streamlit.proto.Components_pb2 import SpecialArg
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
from streamlit.runtime import Runtime, RuntimeConfig
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.scriptrunner import ScriptRunContext, add_script_run_ctx
from streamlit.type_util import to_bytes
from tests.delta_generator_test_case import DeltaGeneratorTestCase

Expand Down Expand Up @@ -75,6 +78,10 @@ def setUp(self) -> None:
)
self.runtime = Runtime(config)

# declare_component needs a script_run_ctx to be set
self.script_run_ctx = MagicMock(spec=ScriptRunContext)
add_script_run_ctx(threading.current_thread(), self.script_run_ctx)
Comment on lines +81 to +83
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has to be added if you want to run the tests here standalone like pytest lib/tests/streamlit/components_test.py. In the CI it does not fail due to pure luck as apparently the order of execution of the tests ensures that a runtime exists 🫠


def tearDown(self) -> None:
Runtime._instance = None

Expand Down Expand Up @@ -484,6 +491,34 @@ def test_df_default(self):
proto.special_args[0],
)

def test_on_change_handler(self):
"""Test the 'on_change' callback param."""

# we use a list here so that we can update it in the lambda; we cannot assign a variable there.
callback_call_value = []
expected_element_value = "Called with foo"

def create_on_change_handler(some_arg: str):
return lambda: callback_call_value.append("Called with " + some_arg)

return_value = self.test_component(
key="key", default="baz", on_change=create_on_change_handler("foo")
)
self.assertEqual("baz", return_value)

proto = self.get_delta_from_queue().new_element.component_instance
self.assertJSONEqual({"key": "key", "default": "baz"}, proto.json_args)
current_widget_states = self.script_run_ctx.session_state.get_widget_states()
new_widget_state = WidgetState()
# copy the custom components state and update the value
new_widget_state.CopyFrom(current_widget_states[0])
# update the widget's value so that the rerun will execute the callback
new_widget_state.json_value = '{"key": "key", "default": "baz2"}'
self.script_run_ctx.session_state.on_script_will_rerun(
WidgetStates(widgets=[new_widget_state])
)
self.assertEqual(callback_call_value[0], expected_element_value)

def assertJSONEqual(self, a, b):
"""Asserts that two JSON dicts are equal. If either arg is a string,
it will be first converted to a dict with json.loads()."""
Expand Down Expand Up @@ -559,6 +594,9 @@ def get_component_path(self, name: str) -> str | None:
def get_module_name(self, name: str) -> str | None:
return None

def get_component(self, name: str) -> BaseCustomComponent | None:
return None

def get_components(self) -> list[BaseCustomComponent]:
return []

Expand Down
Loading