Skip to content

Commit

Permalink
Allow passing on_change_callback for CustomComponents (#8633)
Browse files Browse the repository at this point in the history
## Describe your changes

In the past, some custom components have used a
patch (https://gist.github.com/okld/1a2b2fd2cb9f85fc8c4e92e26c6597d5) to
register an on_change callback. Recently, we have done some refactoring
that broke this workaround. This PR is a suggestion to extend our
official API to make the patch redundant.

Note that we only want to pass the `on_change_callback` and not the
`args` and `kwargs`.
The `register_widget` function today uses `args` and `kwargs` as
keywords to pass to the `on_change callback`. Besides the unfortunate
naming - these are special keywords meant for functions themselves and
not for pass-through arguments - we are thinking about deprecating them
entirely, since you can wrap the callback easily to pass the arguments.

## GitHub Issue Link (if applicable)

Closes #3977
Related to victoryhb/streamlit-option-menu#70

## Testing Plan

- Explanation of why no additional tests are needed
- Unit Tests (JS and/or Python)
- A new unit test is added to make sure the on_change callback is called
when the value changes during a ScriptRun
- E2E Tests
  - prepare `on_change` callback test in the `option_menu` function
- Any manual testing needed?
- I manually tested it on the example of
[streamlit-option-menu](https://github.com/victoryhb/streamlit-option-menu)

---

**Contribution License Agreement**

By submitting this pull request you agree that all contributions to this
project are made under the Apache 2.0 license.
  • Loading branch information
raethlein committed May 22, 2024
1 parent 30d8bdf commit ad3dc4e
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 4 deletions.
9 changes: 9 additions & 0 deletions e2e_playwright/custom_components/popular_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,22 @@ def use_folium():
def use_option_menu():
from streamlit_option_menu import option_menu

key = "my_option_menu"

# TODO: uncomment the on_change callback as soon as streamlit-option-menu is updated and uses the new on_change callback
# def on_change():
# selection = st.session_state[key]
# st.write(f"Selection changed to {selection}")

with st.sidebar:
selected = option_menu(
"Main Menu",
["Home", "Settings"],
icons=["house", "gear"],
menu_icon="cast",
default_index=1,
key=key,
# on_change=on_change,
)
st.write(selected)

Expand Down
7 changes: 7 additions & 0 deletions e2e_playwright/custom_components/popular_components_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ def test_option_menu(app: Page):
_expect_no_exception(app)
_expect_iframe_attached(app)

# TODO: uncomment the on_change callback as soon as streamlit-option-menu is updated and uses the new on_change callback
# frame_locator = app.frame_locator("iframe")
# frame_locator.locator("a", has_text="Home").click()
# expect(
# app.get_by_test_id("stMarkdown").filter(has_text="Selection changed to Home")
# ).to_be_visible()


def test_url_fragment(app: Page):
"""Test that the url-fragment component renders"""
Expand Down
17 changes: 15 additions & 2 deletions lib/streamlit/components/types/base_custom_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

import os
from abc import ABC, abstractmethod
from typing import Any
from typing import TYPE_CHECKING, Any

from streamlit import util
from streamlit.errors import StreamlitAPIException

if TYPE_CHECKING:
from streamlit.runtime.state.common import WidgetCallback


class MarshallComponentException(StreamlitAPIException):
"""Class for exceptions generated during custom component marshalling."""
Expand Down Expand Up @@ -56,10 +59,17 @@ def __call__(
*args,
default: Any = None,
key: str | None = None,
on_change: WidgetCallback | None = None,
**kwargs,
) -> Any:
"""An alias for create_instance."""
return self.create_instance(*args, default=default, key=key, **kwargs)
return self.create_instance(
*args,
default=default,
key=key,
on_change=on_change,
**kwargs,
)

@property
def abspath(self) -> str | None:
Expand Down Expand Up @@ -102,6 +112,7 @@ def create_instance(
*args,
default: Any = None,
key: str | None = None,
on_change: WidgetCallback | None = None,
**kwargs,
) -> Any:
"""Create a new instance of the component.
Expand All @@ -118,6 +129,8 @@ def create_instance(
key: str or None
If not None, this is the user key we use to generate the
component's "widget ID".
on_change: WidgetCallback or None
An optional callback invoked when the widget's value changes. No arguments are passed to it.
**kwargs
Keyword args to pass to the component.
Expand Down
14 changes: 13 additions & 1 deletion lib/streamlit/components/v1/custom_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
from streamlit.runtime.state.common import WidgetCallback


class MarshallComponentException(StreamlitAPIException):
Expand All @@ -49,17 +50,25 @@ def __call__(
*args,
default: Any = None,
key: str | None = None,
on_change: WidgetCallback | None = None,
**kwargs,
) -> Any:
"""An alias for create_instance."""
return self.create_instance(*args, default=default, key=key, **kwargs)
return self.create_instance(
*args,
default=default,
key=key,
on_change=on_change,
**kwargs,
)

@gather_metrics("create_instance")
def create_instance(
self,
*args,
default: Any = None,
key: str | None = None,
on_change: WidgetCallback | None = None,
**kwargs,
) -> Any:
"""Create a new instance of the component.
Expand All @@ -76,6 +85,8 @@ def create_instance(
key: str or None
If not None, this is the user key we use to generate the
component's "widget ID".
on_change: WidgetCallback or None
An optional callback invoked when the widget's value changes. No arguments are passed to it.
**kwargs
Keyword args to pass to the component.
Expand Down Expand Up @@ -195,6 +206,7 @@ def deserialize_component(ui_value, widget_id=""):
deserializer=deserialize_component,
serializer=lambda x: x,
ctx=ctx,
on_change_handler=on_change,
)
widget_value = component_state.value

Expand Down
2 changes: 1 addition & 1 deletion lib/streamlit/runtime/state/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def user_key_from_widget_id(widget_id: str) -> str | None:
"None" as a key, but we can't avoid this kind of problem while storing the
string representation of the no-user-key sentinel as part of the widget id.
"""
user_key = widget_id.split("-", maxsplit=2)[-1]
user_key: str | None = widget_id.split("-", maxsplit=2)[-1]
user_key = None if user_key == "None" else user_key
return user_key

Expand Down
38 changes: 38 additions & 0 deletions lib/tests/streamlit/components_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import inspect
import json
import os
import threading
import unittest
from typing import Any
from unittest import mock
Expand All @@ -38,9 +39,11 @@
from streamlit.components.v1.custom_component import CustomComponent
from streamlit.errors import DuplicateWidgetID, StreamlitAPIException
from streamlit.proto.Components_pb2 import SpecialArg
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
from streamlit.runtime import Runtime, RuntimeConfig
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.scriptrunner import ScriptRunContext, add_script_run_ctx
from streamlit.type_util import to_bytes
from tests.delta_generator_test_case import DeltaGeneratorTestCase

Expand Down Expand Up @@ -75,6 +78,10 @@ def setUp(self) -> None:
)
self.runtime = Runtime(config)

# declare_component needs a script_run_ctx to be set
self.script_run_ctx = MagicMock(spec=ScriptRunContext)
add_script_run_ctx(threading.current_thread(), self.script_run_ctx)

def tearDown(self) -> None:
Runtime._instance = None

Expand Down Expand Up @@ -484,6 +491,34 @@ def test_df_default(self):
proto.special_args[0],
)

def test_on_change_handler(self):
"""Test the 'on_change' callback param."""

# we use a list here so that we can update it in the lambda; we cannot assign a variable there.
callback_call_value = []
expected_element_value = "Called with foo"

def create_on_change_handler(some_arg: str):
return lambda: callback_call_value.append("Called with " + some_arg)

return_value = self.test_component(
key="key", default="baz", on_change=create_on_change_handler("foo")
)
self.assertEqual("baz", return_value)

proto = self.get_delta_from_queue().new_element.component_instance
self.assertJSONEqual({"key": "key", "default": "baz"}, proto.json_args)
current_widget_states = self.script_run_ctx.session_state.get_widget_states()
new_widget_state = WidgetState()
# copy the custom components state and update the value
new_widget_state.CopyFrom(current_widget_states[0])
# update the widget's value so that the rerun will execute the callback
new_widget_state.json_value = '{"key": "key", "default": "baz2"}'
self.script_run_ctx.session_state.on_script_will_rerun(
WidgetStates(widgets=[new_widget_state])
)
self.assertEqual(callback_call_value[0], expected_element_value)

def assertJSONEqual(self, a, b):
"""Asserts that two JSON dicts are equal. If either arg is a string,
it will be first converted to a dict with json.loads()."""
Expand Down Expand Up @@ -559,6 +594,9 @@ def get_component_path(self, name: str) -> str | None:
def get_module_name(self, name: str) -> str | None:
return None

def get_component(self, name: str) -> BaseCustomComponent | None:
return None

def get_components(self) -> list[BaseCustomComponent]:
return []

Expand Down

0 comments on commit ad3dc4e

Please sign in to comment.