Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move check_* utils to own file #8764

Merged
merged 7 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/streamlit/components/v1/custom_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from streamlit import _main, type_util
from streamlit.components.types.base_custom_component import BaseCustomComponent
from streamlit.elements.form import current_form_id
from streamlit.elements.utils import check_cache_replay_rules
from streamlit.elements.lib.policies import check_cache_replay_rules
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Components_pb2 import ArrowTable as ArrowTableProto
from streamlit.proto.Components_pb2 import SpecialArg
Expand Down
13 changes: 5 additions & 8 deletions lib/streamlit/elements/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
)
from streamlit.elements.lib.event_utils import AttributeDictionary
from streamlit.elements.lib.pandas_styler_utils import marshall_styler
from streamlit.elements.lib.policies import (
check_cache_replay_rules,
check_callback_rules,
check_session_state_rules,
)
Dismissed Show dismissed Hide dismissed
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Arrow_pb2 import Arrow as ArrowProto
from streamlit.runtime.metrics_util import gather_metrics
Expand Down Expand Up @@ -476,14 +481,6 @@

if is_selection_activated:
# Run some checks that are only relevant when selections are activated

# Import here to avoid circular imports
from streamlit.elements.utils import (
check_cache_replay_rules,
check_callback_rules,
check_session_state_rules,
)

check_cache_replay_rules()
if callable(on_select):
check_callback_rules(self.dg, on_select)
Expand Down
2 changes: 1 addition & 1 deletion lib/streamlit/elements/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,10 @@

"""
# Import this here to avoid circular imports.
from streamlit.elements.utils import (
from streamlit.elements.lib.policies import (
check_cache_replay_rules,
check_session_state_rules,
)

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
streamlit.elements.lib.policies
begins an import cycle.

if is_in_form(self.dg):
raise StreamlitAPIException("Forms cannot be nested in other forms.")
Expand Down
122 changes: 122 additions & 0 deletions lib/streamlit/elements/lib/policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2024)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from streamlit import config, runtime
from streamlit.elements.form import is_in_form
Dismissed Show dismissed Hide dismissed
from streamlit.errors import StreamlitAPIException, StreamlitAPIWarning
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.state import WidgetCallback, get_session_state

if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator


def check_callback_rules(dg: DeltaGenerator, on_change: WidgetCallback | None) -> None:
"""Ensures that widgets other than `st.form_submit` within a form don't have an on_change callback set.

Raises
------
StreamlitAPIException:
Raised when the described rule is violated.
"""

if runtime.exists() and is_in_form(dg) and on_change is not None:
raise StreamlitAPIException(
"With forms, callbacks can only be defined on the `st.form_submit_button`."
" Defining callbacks on other widgets inside a form is not allowed."
)


_shown_default_value_warning: bool = False


def check_session_state_rules(
default_value: Any, key: str | None, writes_allowed: bool = True
) -> None:
"""Ensures that no values are set for widgets with the given key when writing is not allowed.

Additionally, if `global.disableWidgetStateDuplicationWarning` is False a warning is shown when a widget has a default value but its value is also set via session state.

Raises
------
StreamlitAPIException:
Raised when the described rule is violated.
"""
global _shown_default_value_warning

if key is None or not runtime.exists():
return

session_state = get_session_state()
if not session_state.is_new_state_value(key):
return

if not writes_allowed:
raise StreamlitAPIException(
f'Values for the widget with key "{key}" cannot be set using `st.session_state`.'
)

if (
default_value is not None
and not _shown_default_value_warning
and not config.get_option("global.disableWidgetStateDuplicationWarning")
):
from streamlit import warning

warning(
f'The widget with key "{key}" was created with a default value but'
" also had its value set via the Session State API."
)
_shown_default_value_warning = True
Dismissed Show dismissed Hide dismissed


class CachedWidgetWarning(StreamlitAPIWarning):
def __init__(self):
super().__init__(
"""
Your script uses a widget command in a cached function
(function decorated with `@st.cache_data` or `@st.cache_resource`).
This code will only be called when we detect a cache "miss",
which can lead to unexpected results.

How to fix this:
* Move all widget commands outside the cached function.
* Or, if you know what you're doing, use `experimental_allow_widgets=True`
in the cache decorator to enable widget replay and suppress this warning.
"""
)


def check_cache_replay_rules() -> None:
"""Check if a widget is allowed to be used in the current context.
More specifically, this checks if the current context is inside a
cached function that disallows widget usage. If so, it raises a warning.

If there are other similar checks in the future, we could extend this
function to check for those as well. And rename it to check_widget_usage_rules.
"""
if runtime.exists():
# from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx

ctx = get_script_run_ctx()
if ctx and ctx.disallow_cached_widget_usage:
from streamlit import exception

# We use an exception here to show a proper stack trace
# that indicates to the user where the issue is.
exception(CachedWidgetWarning())
Original file line number Diff line number Diff line change
Expand Up @@ -15,95 +15,13 @@
from __future__ import annotations

from enum import Enum, EnumMeta
from typing import TYPE_CHECKING, Any, Iterable, Sequence, overload
from typing import Any, Iterable, Sequence, overload

import streamlit
from streamlit import config, runtime, type_util
from streamlit.elements.form import is_in_form
from streamlit.errors import StreamlitAPIException, StreamlitAPIWarning
from streamlit import type_util
from streamlit.proto.LabelVisibilityMessage_pb2 import LabelVisibilityMessage
from streamlit.runtime.state import WidgetCallback, get_session_state
from streamlit.runtime.state.common import RegisterWidgetResult
from streamlit.type_util import T

if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator


def check_callback_rules(dg: DeltaGenerator, on_change: WidgetCallback | None) -> None:
if runtime.exists() and is_in_form(dg) and on_change is not None:
raise StreamlitAPIException(
"With forms, callbacks can only be defined on the `st.form_submit_button`."
" Defining callbacks on other widgets inside a form is not allowed."
)


_shown_default_value_warning: bool = False


def check_session_state_rules(
default_value: Any, key: str | None, writes_allowed: bool = True
) -> None:
global _shown_default_value_warning

if key is None or not runtime.exists():
return

session_state = get_session_state()
if not session_state.is_new_state_value(key):
return

if not writes_allowed:
raise StreamlitAPIException(
f'Values for the widget with key "{key}" cannot be set using `st.session_state`.'
)

if (
default_value is not None
and not _shown_default_value_warning
and not config.get_option("global.disableWidgetStateDuplicationWarning")
):
streamlit.warning(
f'The widget with key "{key}" was created with a default value but'
" also had its value set via the Session State API."
)
_shown_default_value_warning = True


class CachedWidgetWarning(StreamlitAPIWarning):
def __init__(self):
super().__init__(
"""
Your script uses a widget command in a cached function
(function decorated with `@st.cache_data` or `@st.cache_resource`).
This code will only be called when we detect a cache "miss",
which can lead to unexpected results.

How to fix this:
* Move all widget commands outside the cached function.
* Or, if you know what you're doing, use `experimental_allow_widgets=True`
in the cache decorator to enable widget replay and suppress this warning.
"""
)


def check_cache_replay_rules() -> None:
"""Check if a widget is allowed to be used in the current context.
More specifically, this checks if the current context is inside a
cached function that disallows widget usage. If so, it raises a warning.

If there are other similar checks in the future, we could extend this
function to check for those as well. And rename it to check_widget_usage_rules.
"""
if runtime.exists():
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx

ctx = get_script_run_ctx()
if ctx and ctx.disallow_cached_widget_usage:
# We use an exception here to show a proper stack trace
# that indicates to the user where the issue is.
streamlit.exception(CachedWidgetWarning())


def get_label_visibility_proto_value(
label_visibility_string: type_util.LabelVisibility,
Expand Down
2 changes: 1 addition & 1 deletion lib/streamlit/elements/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from typing_extensions import TypeAlias

from streamlit.elements.utils import get_label_visibility_proto_value
from streamlit.elements.lib.utils import get_label_visibility_proto_value
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Metric_pb2 import Metric as MetricProto
from streamlit.runtime.metrics_util import gather_metrics
Expand Down
31 changes: 20 additions & 11 deletions lib/streamlit/elements/plotly_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.elements.form import current_form_id
from streamlit.elements.lib.event_utils import AttributeDictionary
from streamlit.elements.lib.policies import (
check_cache_replay_rules,
check_callback_rules,
check_session_state_rules,
)
Comment on lines +41 to +45

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
streamlit.elements.lib.policies
begins an import cycle.
from streamlit.elements.lib.streamlit_plotly_theme import (
configure_streamlit_plotly_theme,
)
Expand Down Expand Up @@ -274,7 +279,11 @@
key: Key | None = None,
on_select: Literal["ignore"], # No default value here to make it work with mypy
selection_mode: SelectionMode
| Iterable[SelectionMode] = ("points", "box", "lasso"),
| Iterable[SelectionMode] = (
"points",
"box",
"lasso",
),
**kwargs: Any,
) -> DeltaGenerator:
...
Expand All @@ -289,7 +298,11 @@
key: Key | None = None,
on_select: Literal["rerun"] | WidgetCallback = "rerun",
selection_mode: SelectionMode
| Iterable[SelectionMode] = ("points", "box", "lasso"),
| Iterable[SelectionMode] = (
"points",
"box",
"lasso",
),
**kwargs: Any,
) -> PlotlyState:
...
Expand All @@ -304,7 +317,11 @@
key: Key | None = None,
on_select: Literal["rerun", "ignore"] | WidgetCallback = "ignore",
selection_mode: SelectionMode
| Iterable[SelectionMode] = ("points", "box", "lasso"),
| Iterable[SelectionMode] = (
"points",
"box",
"lasso",
),
**kwargs: Any,
) -> DeltaGenerator | PlotlyState:
"""Display an interactive Plotly chart.
Expand Down Expand Up @@ -450,14 +467,6 @@

if is_selection_activated:
# Run some checks that are only relevant when selections are activated

# Import here to avoid circular imports
from streamlit.elements.utils import (
check_cache_replay_rules,
check_callback_rules,
check_session_state_rules,
)

check_cache_replay_rules()
if callable(on_select):
check_callback_rules(self.dg, on_select)
Expand Down
12 changes: 5 additions & 7 deletions lib/streamlit/elements/vega_charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
generate_chart,
)
from streamlit.elements.lib.event_utils import AttributeDictionary
from streamlit.elements.lib.policies import (
check_cache_replay_rules,
check_callback_rules,
check_session_state_rules,
)
Comment on lines +46 to +50

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
streamlit.elements.lib.policies
begins an import cycle.
from streamlit.errors import StreamlitAPIException
from streamlit.proto.ArrowVegaLiteChart_pb2 import (
ArrowVegaLiteChart as ArrowVegaLiteChartProto,
Expand Down Expand Up @@ -1651,13 +1656,6 @@
if is_selection_activated:
# Run some checks that are only relevant when selections are activated

# Import here to avoid circular imports
from streamlit.elements.utils import (
check_cache_replay_rules,
check_callback_rules,
check_session_state_rules,
)

check_cache_replay_rules()
if callable(on_select):
check_callback_rules(self.dg, on_select)
Expand Down
Loading
Loading