Skip to content

Commit

Permalink
Add color support to charts, and refactor things a bit.
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-tteixeira committed Jul 16, 2023
1 parent b3c9b81 commit 299d047
Show file tree
Hide file tree
Showing 11 changed files with 1,869 additions and 313 deletions.
98 changes: 53 additions & 45 deletions lib/streamlit/delta_generator.py
Expand Up @@ -24,6 +24,7 @@
Hashable,
Iterable,
NoReturn,
Optional,
Type,
TypeVar,
cast,
Expand All @@ -36,11 +37,12 @@
from streamlit import config, cursor, env_util, logger, runtime, type_util, util
from streamlit.cursor import Cursor
from streamlit.elements.alert import AlertMixin
from streamlit.elements.altair_utils import AddRowsMetadata

# DataFrame elements come in two flavors: "Legacy" and "Arrow".
# We select between them with the DataFrameElementSelectorMixin.
from streamlit.elements.arrow import ArrowMixin
from streamlit.elements.arrow_altair import ArrowAltairMixin
from streamlit.elements.arrow_altair import ArrowAltairMixin, prep_data
from streamlit.elements.arrow_vega_lite import ArrowVegaLiteMixin
from streamlit.elements.balloons import BalloonsMixin
from streamlit.elements.bokeh_chart import BokehMixin
Expand All @@ -64,7 +66,7 @@
from streamlit.elements.image import ImageMixin
from streamlit.elements.json import JsonMixin
from streamlit.elements.layouts import LayoutsMixin
from streamlit.elements.legacy_altair import LegacyAltairMixin
from streamlit.elements.legacy_altair import ArrowNotSupportedError, LegacyAltairMixin
from streamlit.elements.legacy_data_frame import LegacyDataFrameMixin
from streamlit.elements.legacy_vega_lite import LegacyVegaLiteMixin
from streamlit.elements.map import MapMixin
Expand Down Expand Up @@ -113,10 +115,12 @@
"arrow_line_chart",
"arrow_area_chart",
"arrow_bar_chart",
"arrow_scatter_chart",
)

Value = TypeVar("Value")
DG = TypeVar("DG", bound="DeltaGenerator")
DFT = TypeVar("DFT", bound=type_util.DataFrameCompatible)

# Type aliases for Parent Block Types
BlockType = str
Expand Down Expand Up @@ -417,7 +421,7 @@ def _enqueue( # type: ignore[misc]
delta_type: str,
element_proto: Message,
return_value: None,
last_index: Hashable | None = None,
add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> DeltaGenerator:
Expand All @@ -429,7 +433,7 @@ def _enqueue( # type: ignore[misc]
delta_type: str,
element_proto: Message,
return_value: Type[NoValue],
last_index: Hashable | None = None,
add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> None:
Expand All @@ -441,7 +445,7 @@ def _enqueue( # type: ignore[misc]
delta_type: str,
element_proto: Message,
return_value: Value,
last_index: Hashable | None = None,
add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> Value:
Expand All @@ -453,7 +457,7 @@ def _enqueue(
delta_type: str,
element_proto: Message,
return_value: None = None,
last_index: Hashable | None = None,
add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> DeltaGenerator:
Expand All @@ -465,7 +469,7 @@ def _enqueue(
delta_type: str,
element_proto: Message,
return_value: Type[NoValue] | Value | None = None,
last_index: Hashable | None = None,
add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> DeltaGenerator | Value | None:
Expand All @@ -476,7 +480,7 @@ def _enqueue(
delta_type: str,
element_proto: Message,
return_value: Type[NoValue] | Value | None = None,
last_index: Hashable | None = None,
add_rows_metadata: Optional[AddRowsMetadata] = None,
element_width: int | None = None,
element_height: int | None = None,
) -> DeltaGenerator | Value | None:
Expand Down Expand Up @@ -549,7 +553,7 @@ def _enqueue(
# position.
new_cursor = (
dg._cursor.get_locked_cursor(
delta_type=delta_type, last_index=last_index
delta_type=delta_type, add_rows_metadata=add_rows_metadata
)
if dg._cursor is not None
else None
Expand Down Expand Up @@ -637,7 +641,7 @@ def _block(
block_dg._form_data = FormData(current_form_id(dg))

# Must be called to increment this cursor's index.
dg._cursor.get_locked_cursor(last_index=None)
dg._cursor.get_locked_cursor(add_rows_metadata=None)
_enqueue_message(msg)

caching.save_block_message(
Expand Down Expand Up @@ -732,12 +736,16 @@ def _legacy_add_rows(
"Command requires exactly one dataset"
)

# The legacy add_rows does not support Arrow tables.
if type_util.is_type(data, "pyarrow.lib.Table"):
raise ArrowNotSupportedError()

# When doing _legacy_add_rows on an element that does not already have data
# (for example, st._legacy_line_chart() without any args), call the original
# st._legacy_foo() element with new data instead of doing a _legacy_add_rows().
if (
self._cursor.props["delta_type"] in DELTA_TYPES_THAT_MELT_DATAFRAMES
and self._cursor.props["last_index"] is None
and self._cursor.props["add_rows_metadata"].last_index is None
):
# IMPORTANT: This assumes delta types and st method names always
# match!
Expand All @@ -747,8 +755,11 @@ def _legacy_add_rows(
st_method(data, **kwargs)
return None

data, self._cursor.props["last_index"] = _maybe_melt_data_for_add_rows(
data, self._cursor.props["delta_type"], self._cursor.props["last_index"]
data, self._cursor.props["add_rows_metadata"] = _prep_data_for_add_rows(
data,
self._cursor.props["delta_type"],
self._cursor.props["add_rows_metadata"],
is_legacy=True,
)

msg = ForwardMsg_pb2.ForwardMsg()
Expand Down Expand Up @@ -853,7 +864,7 @@ def _arrow_add_rows(
# st._arrow_foo() element with new data instead of doing a _arrow_add_rows().
if (
self._cursor.props["delta_type"] in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES
and self._cursor.props["last_index"] is None
and self._cursor.props["add_rows_metadata"].last_index is None
):
# IMPORTANT: This assumes delta types and st method names always
# match!
Expand All @@ -863,8 +874,11 @@ def _arrow_add_rows(
st_method(data, **kwargs)
return None

data, self._cursor.props["last_index"] = _maybe_melt_data_for_add_rows(
data, self._cursor.props["delta_type"], self._cursor.props["last_index"]
data, self._cursor.props["add_rows_metadata"] = _prep_data_for_add_rows(
data,
self._cursor.props["delta_type"],
self._cursor.props["add_rows_metadata"],
is_legacy=False,
)

msg = ForwardMsg_pb2.ForwardMsg()
Expand All @@ -884,17 +898,24 @@ def _arrow_add_rows(
return self


DFT = TypeVar("DFT", bound=type_util.DataFrameCompatible)


def _maybe_melt_data_for_add_rows(
def _prep_data_for_add_rows(
data: DFT,
delta_type: str,
last_index: Any,
add_rows_metadata: AddRowsMetadata,
is_legacy: bool,
) -> tuple[DFT | DataFrame, int | Any]:
import pandas as pd

def _melt_data(df: DataFrame, last_index: Any) -> tuple[DataFrame, int | Any]:
df = cast(pd.DataFrame, type_util.convert_anything_to_df(data, allow_styler=True))

# For some delta types we have to reshape the data structure
# otherwise the input data and the actual data used
# by vega_lite will be different, and it will throw an error.
if (
delta_type in DELTA_TYPES_THAT_MELT_DATAFRAMES
or delta_type in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES
):
# Make range indices start at last_index.
if isinstance(df.index, pd.RangeIndex):
old_step = _get_pandas_index_attr(df, "step")

Expand All @@ -908,35 +929,22 @@ def _melt_data(df: DataFrame, last_index: Any) -> tuple[DataFrame, int | Any]:
"'RangeIndex' object has no attribute 'step'"
)

start = last_index + old_step
stop = last_index + old_step + old_stop
start = add_rows_metadata.last_index + old_step
stop = add_rows_metadata.last_index + old_step + old_stop

df.index = pd.RangeIndex(start=start, stop=stop, step=old_step)
last_index = stop - 1

index_name = df.index.name
if index_name is None:
index_name = "index"
add_rows_metadata.last_index = stop - 1

df = pd.melt(df.reset_index(), id_vars=[index_name])
return df, last_index
if is_legacy:
index_name = df.index.name
if index_name is None:
index_name = "index"

# For some delta types we have to reshape the data structure
# otherwise the input data and the actual data used
# by vega_lite will be different, and it will throw an error.
if (
delta_type in DELTA_TYPES_THAT_MELT_DATAFRAMES
or delta_type in ARROW_DELTA_TYPES_THAT_MELT_DATAFRAMES
):
if not isinstance(data, pd.DataFrame):
return _melt_data(
df=type_util.convert_anything_to_df(data),
last_index=last_index,
)
df = pd.melt(df.reset_index(), id_vars=[index_name])
else:
return _melt_data(df=data, last_index=last_index)
df, *_ = prep_data(df, **add_rows_metadata.columns)

return data, last_index
return df, add_rows_metadata


def _get_pandas_index_attr(
Expand Down
37 changes: 37 additions & 0 deletions lib/streamlit/elements/altair_utils.py
@@ -0,0 +1,37 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Useful classes for our native Altair-based charts.
These classes are used in both Arrow-based and legacy-based charting code to pass some
important info to add_rows.
"""

from dataclasses import dataclass
from typing import Hashable, List, Optional, TypedDict


class PrepDataColumns(TypedDict):
    """Columns used for the prep_data step in Altair Arrow charts.

    These keys are passed as keyword arguments into ``prep_data`` (via
    ``**add_rows_metadata.columns``) when ``add_rows`` reshapes newly
    appended data the same way the original chart call did.
    """

    # Name of the column used for the x axis, or None if not set.
    x_column: Optional[str]
    # Wide-format y columns that prep_data melts into long format.
    wide_y_columns: List[str]
    # Column driving per-series color, or None if the chart has no color.
    color_column: Optional[str]


@dataclass
class AddRowsMetadata:
    """Metadata needed by add_rows on native charts.

    An instance is stored in the element cursor's props when a chart is
    created, so that later ``add_rows`` calls can continue the chart's
    index and reshape incoming data consistently.
    """

    # Last value of the chart data's RangeIndex, used so appended rows
    # continue numbering where the previous data stopped instead of
    # restarting at 0. None when the chart was created without data.
    last_index: Optional[Hashable]
    # Column names forwarded (as keyword args) to prep_data when melting
    # newly appended rows.
    columns: PrepDataColumns

0 comments on commit 299d047

Please sign in to comment.