Skip to content

Commit

Permalink
Refactor vega-related charting (#8595)
Browse files Browse the repository at this point in the history
## Describe your changes

This PR refactors/restructures our backend logic that handles all
vega-related charting: `st.vega_lite_chart`, `st.altair_chart`,
`st.line_chart`, `st.area_chart`, `st.bar_chart`, `st.scatter_chart`.

- The built-in charting logic is migrated to `built_in_chart_utils` 
- All commands are migrated to `vega_charts` and refactored to reuse
more shared logic.
- This also applies some refactoring to how we handle adding `add_rows`
metadata by just relaying on the available metadata instead of checking
for the delta type.

## Testing Plan

- Updated unit tests
- No big logical changes.

---

**Contribution License Agreement**

By submitting this pull request you agree that all contributions to this
project are made under the Apache 2.0 license.
  • Loading branch information
LukasMasuch committed May 3, 2024
1 parent d93a282 commit 67d0b04
Show file tree
Hide file tree
Showing 8 changed files with 1,483 additions and 1,463 deletions.
98 changes: 22 additions & 76 deletions lib/streamlit/delta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,7 @@
)
from streamlit.cursor import Cursor
from streamlit.elements.alert import AlertMixin
from streamlit.elements.altair_utils import AddRowsMetadata
from streamlit.elements.arrow import ArrowMixin
from streamlit.elements.arrow_altair import ArrowAltairMixin, prep_data
from streamlit.elements.arrow_vega_lite import ArrowVegaLiteMixin
from streamlit.elements.balloons import BalloonsMixin
from streamlit.elements.bokeh_chart import BokehMixin
from streamlit.elements.code import CodeMixin
Expand All @@ -75,6 +72,7 @@
from streamlit.elements.snow import SnowMixin
from streamlit.elements.text import TextMixin
from streamlit.elements.toast import ToastMixin
from streamlit.elements.vega_charts import VegaChartsMixin
from streamlit.elements.widgets.button import ButtonMixin
from streamlit.elements.widgets.camera_input import CameraInputMixin
from streamlit.elements.widgets.chat import ChatMixin
Expand All @@ -94,7 +92,7 @@
from streamlit.errors import NoSessionContext, StreamlitAPIException
from streamlit.proto import Block_pb2, ForwardMsg_pb2
from streamlit.proto.RootContainer_pb2 import RootContainer
from streamlit.runtime import caching, legacy_caching
from streamlit.runtime import caching
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit.runtime.state import NoValue

Expand All @@ -104,19 +102,11 @@
from pandas import DataFrame, Series

from streamlit.elements.arrow import Data
from streamlit.elements.lib.built_in_chart_utils import AddRowsMetadata


MAX_DELTA_BYTES: Final[int] = 14 * 1024 * 1024 # 14MB

# List of Streamlit commands that perform a Pandas "melt" operation on
# input dataframes:
ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES: Final = (
"arrow_line_chart",
"arrow_area_chart",
"arrow_bar_chart",
"arrow_scatter_chart",
)

Value = TypeVar("Value")
DG = TypeVar("DG", bound="DeltaGenerator")

Expand Down Expand Up @@ -197,8 +187,7 @@ class DeltaGenerator(
ToastMixin,
WriteMixin,
ArrowMixin,
ArrowAltairMixin,
ArrowVegaLiteMixin,
VegaChartsMixin,
DataEditorMixin,
):
"""Creator of Delta protobuf messages.
Expand Down Expand Up @@ -531,18 +520,9 @@ def _enqueue(
# Warn if an element is being changed but the user isn't running the streamlit server.
_maybe_print_use_warning()

# Some elements have a method.__name__ != delta_type in proto.
# This really matters for line_chart, bar_chart & area_chart,
# since add_rows() relies on method.__name__ == delta_type
# TODO: Fix for all elements (or the cache warning above will be wrong)
proto_type = delta_type

if proto_type in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES:
proto_type = "arrow_vega_lite_chart"

# Copy the marshalled proto into the overall msg proto
msg = ForwardMsg_pb2.ForwardMsg()
msg_el_proto = getattr(msg.delta.new_element, proto_type)
msg_el_proto = getattr(msg.delta.new_element, delta_type)
msg_el_proto.CopyFrom(element_proto)

# Only enqueue message and fill in metadata if there's a container.
Expand Down Expand Up @@ -732,22 +712,20 @@ def _arrow_add_rows(

# When doing _arrow_add_rows on an element that does not already have data
# (for example, st.line_chart() without any args), call the original
# st._arrow_foo() element with new data instead of doing a _arrow_add_rows().
# st.foo() element with new data instead of doing a _arrow_add_rows().
if (
self._cursor.props["delta_type"] in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES
"add_rows_metadata" in self._cursor.props
and self._cursor.props["add_rows_metadata"]
and self._cursor.props["add_rows_metadata"].last_index is None
):
# IMPORTANT: This assumes delta types and st method names always
# match!
# delta_type starts with "arrow_", but st_method_name doesn't use this prefix.
st_method_name = self._cursor.props["delta_type"].replace("arrow_", "")
st_method = getattr(self, st_method_name)
st_method = getattr(
self, self._cursor.props["add_rows_metadata"].chart_command
)
st_method(data, **kwargs)
return None

new_data, self._cursor.props["add_rows_metadata"] = _prep_data_for_add_rows(
data,
self._cursor.props["delta_type"],
self._cursor.props["add_rows_metadata"],
)

Expand Down Expand Up @@ -797,53 +775,21 @@ def get_last_dg_added_to_context_stack() -> DeltaGenerator | None:

def _prep_data_for_add_rows(
data: Data,
delta_type: str,
add_rows_metadata: AddRowsMetadata,
) -> tuple[Data, AddRowsMetadata]:
out_data: Data

# For some delta types we have to reshape the data structure
# otherwise the input data and the actual data used
# by vega_lite will be different, and it will throw an error.
if delta_type in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES:
import pandas as pd

df = cast(pd.DataFrame, type_util.convert_anything_to_df(data))

# Make range indices start at last_index.
if isinstance(df.index, pd.RangeIndex):
old_step = _get_pandas_index_attr(df, "step")

# We have to drop the predefined index
df = df.reset_index(drop=True)

old_stop = _get_pandas_index_attr(df, "stop")

if old_step is None or old_stop is None:
raise StreamlitAPIException(
"'RangeIndex' object has no attribute 'step'"
)

start = add_rows_metadata.last_index + old_step
stop = add_rows_metadata.last_index + old_step + old_stop

df.index = pd.RangeIndex(start=start, stop=stop, step=old_step)
add_rows_metadata.last_index = stop - 1

out_data, *_ = prep_data(df, **add_rows_metadata.columns)

else:
add_rows_metadata: AddRowsMetadata | None,
) -> tuple[Data, AddRowsMetadata | None]:
if not add_rows_metadata:
# When calling add_rows on st.table or st.dataframe we want styles to pass through.
out_data = type_util.convert_anything_to_df(data, allow_styler=True)
return type_util.convert_anything_to_df(data, allow_styler=True), None

return out_data, add_rows_metadata
# If add_rows_metadata is set, it indicates that the add_rows used called
# on a chart based on our built-in chart commands.

# For built-in chart commands we have to reshape the data structure
# otherwise the input data and the actual data used
# by vega_lite will be different, and it will throw an error.
from streamlit.elements.lib.built_in_chart_utils import prep_chart_data_for_add_rows

def _get_pandas_index_attr(
data: DataFrame | Series,
attr: str,
) -> Any | None:
return getattr(data.index, attr, None)
return prep_chart_data_for_add_rows(data, add_rows_metadata)


@overload
Expand Down
40 changes: 0 additions & 40 deletions lib/streamlit/elements/altair_utils.py

This file was deleted.

Loading

0 comments on commit 67d0b04

Please sign in to comment.