Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor vega-related charting #8595

Merged
merged 23 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 22 additions & 76 deletions lib/streamlit/delta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,7 @@
)
from streamlit.cursor import Cursor
from streamlit.elements.alert import AlertMixin
from streamlit.elements.altair_utils import AddRowsMetadata
from streamlit.elements.arrow import ArrowMixin
from streamlit.elements.arrow_altair import ArrowAltairMixin, prep_data
from streamlit.elements.arrow_vega_lite import ArrowVegaLiteMixin
from streamlit.elements.balloons import BalloonsMixin
from streamlit.elements.bokeh_chart import BokehMixin
from streamlit.elements.code import CodeMixin
Expand All @@ -75,6 +72,7 @@
from streamlit.elements.snow import SnowMixin
from streamlit.elements.text import TextMixin
from streamlit.elements.toast import ToastMixin
from streamlit.elements.vega_charts import VegaChartsMixin
Fixed Show fixed Hide fixed
from streamlit.elements.widgets.button import ButtonMixin
from streamlit.elements.widgets.camera_input import CameraInputMixin
from streamlit.elements.widgets.chat import ChatMixin
Expand All @@ -94,7 +92,7 @@
from streamlit.errors import NoSessionContext, StreamlitAPIException
from streamlit.proto import Block_pb2, ForwardMsg_pb2
from streamlit.proto.RootContainer_pb2 import RootContainer
from streamlit.runtime import caching, legacy_caching
from streamlit.runtime import caching
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit.runtime.state import NoValue

Expand All @@ -104,19 +102,11 @@
from pandas import DataFrame, Series

from streamlit.elements.arrow import Data
from streamlit.elements.lib.built_in_chart_utils import AddRowsMetadata


MAX_DELTA_BYTES: Final[int] = 14 * 1024 * 1024 # 14MB

# List of Streamlit commands that perform a Pandas "melt" operation on
# input dataframes:
ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES: Final = (
"arrow_line_chart",
"arrow_area_chart",
"arrow_bar_chart",
"arrow_scatter_chart",
)

Value = TypeVar("Value")
DG = TypeVar("DG", bound="DeltaGenerator")

Expand Down Expand Up @@ -197,8 +187,7 @@ class DeltaGenerator(
ToastMixin,
WriteMixin,
ArrowMixin,
ArrowAltairMixin,
ArrowVegaLiteMixin,
VegaChartsMixin,
DataEditorMixin,
):
"""Creator of Delta protobuf messages.
Expand Down Expand Up @@ -531,18 +520,9 @@ def _enqueue(
# Warn if an element is being changed but the user isn't running the streamlit server.
_maybe_print_use_warning()

# Some elements have a method.__name__ != delta_type in proto.
# This really matters for line_chart, bar_chart & area_chart,
# since add_rows() relies on method.__name__ == delta_type
# TODO: Fix for all elements (or the cache warning above will be wrong)
proto_type = delta_type

if proto_type in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES:
proto_type = "arrow_vega_lite_chart"

LukasMasuch marked this conversation as resolved.
Show resolved Hide resolved
# Copy the marshalled proto into the overall msg proto
msg = ForwardMsg_pb2.ForwardMsg()
msg_el_proto = getattr(msg.delta.new_element, proto_type)
msg_el_proto = getattr(msg.delta.new_element, delta_type)
msg_el_proto.CopyFrom(element_proto)

# Only enqueue message and fill in metadata if there's a container.
Expand Down Expand Up @@ -732,22 +712,20 @@ def _arrow_add_rows(

# When doing _arrow_add_rows on an element that does not already have data
# (for example, st.line_chart() without any args), call the original
# st._arrow_foo() element with new data instead of doing a _arrow_add_rows().
# st.foo() element with new data instead of doing a _arrow_add_rows().
if (
self._cursor.props["delta_type"] in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES
"add_rows_metadata" in self._cursor.props
and self._cursor.props["add_rows_metadata"]
and self._cursor.props["add_rows_metadata"].last_index is None
):
# IMPORTANT: This assumes delta types and st method names always
# match!
# delta_type starts with "arrow_", but st_method_name doesn't use this prefix.
st_method_name = self._cursor.props["delta_type"].replace("arrow_", "")
st_method = getattr(self, st_method_name)
st_method = getattr(
self, self._cursor.props["add_rows_metadata"].chart_command
)
st_method(data, **kwargs)
return None

new_data, self._cursor.props["add_rows_metadata"] = _prep_data_for_add_rows(
data,
self._cursor.props["delta_type"],
self._cursor.props["add_rows_metadata"],
)

Expand Down Expand Up @@ -797,53 +775,21 @@ def get_last_dg_added_to_context_stack() -> DeltaGenerator | None:

def _prep_data_for_add_rows(
data: Data,
delta_type: str,
add_rows_metadata: AddRowsMetadata,
) -> tuple[Data, AddRowsMetadata]:
out_data: Data

# For some delta types we have to reshape the data structure
# otherwise the input data and the actual data used
# by vega_lite will be different, and it will throw an error.
if delta_type in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES:
import pandas as pd

df = cast(pd.DataFrame, type_util.convert_anything_to_df(data))

# Make range indices start at last_index.
if isinstance(df.index, pd.RangeIndex):
old_step = _get_pandas_index_attr(df, "step")

# We have to drop the predefined index
df = df.reset_index(drop=True)

old_stop = _get_pandas_index_attr(df, "stop")

if old_step is None or old_stop is None:
raise StreamlitAPIException(
"'RangeIndex' object has no attribute 'step'"
)

start = add_rows_metadata.last_index + old_step
stop = add_rows_metadata.last_index + old_step + old_stop

df.index = pd.RangeIndex(start=start, stop=stop, step=old_step)
add_rows_metadata.last_index = stop - 1

out_data, *_ = prep_data(df, **add_rows_metadata.columns)

else:
add_rows_metadata: AddRowsMetadata | None,
) -> tuple[Data, AddRowsMetadata | None]:
if not add_rows_metadata:
# When calling add_rows on st.table or st.dataframe we want styles to pass through.
out_data = type_util.convert_anything_to_df(data, allow_styler=True)
return type_util.convert_anything_to_df(data, allow_styler=True), None

return out_data, add_rows_metadata
# If add_rows_metadata is set, it indicates that the add_rows used called
# on a chart based on our built-in chart commands.

# For built-in chart commands we have to reshape the data structure
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if this function is only relevant for chart commands but not generally for everything that uses arrow / add-row functionality as the comment says, I suggest to rename the function to sth. like _prep_chart_data_for_add_rows and perhaps move it to the new built_in_chart_utils.py

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot easily fully move this method since its used in all add row cases, but I extracted the chart specific logic to built_in_charts_utils.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alrighty, but is this function only relevant for chart types or in general?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The migrated part is only relevant for built-in charts, but _prep_data_for_add_rows is called in all add row cases.

# otherwise the input data and the actual data used
# by vega_lite will be different, and it will throw an error.
from streamlit.elements.lib.built_in_chart_utils import prep_chart_data_for_add_rows

def _get_pandas_index_attr(
data: DataFrame | Series,
attr: str,
) -> Any | None:
return getattr(data.index, attr, None)
return prep_chart_data_for_add_rows(data, add_rows_metadata)


@overload
Expand Down
40 changes: 0 additions & 40 deletions lib/streamlit/elements/altair_utils.py

This file was deleted.

Loading
Loading