Add improved type parsing capabilities for st.data_editor (streamlit#6551)

* Add functionality to check underlying types

* Remove not-implemented types

* Add comment

* Some cleanup

* Add unit test

* Fix unit tests

* Finish unit test

* Add tests for index columns

* Remove type compatibility checks

* Remove refactoring

* Remove changes to column config object

* Remove final import

* Fix test issue

* Add dtype object to empty series for compatibility

* Add negative int and float to test

* Add a couple of comments about column data kind
LukasMasuch authored and zyxue committed Apr 16, 2024
1 parent 47193b1 commit 7870852
Showing 5 changed files with 837 additions and 102 deletions.
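
At its core, the change replaces dtype-based parsing of edited cell values with parsing driven by a per-column ColumnDataKind derived from the dataframe and its Arrow schema. A rough sketch of the new flow, using only names introduced in the diff below (the behavior of the helpers lives in column_config_utils and is assumed here):

    import pandas as pd
    import pyarrow as pa
    from streamlit.elements.lib.column_config_utils import determine_dataframe_schema

    df = pd.DataFrame({"when": pd.to_datetime(["2023-01-01"]), "count": [1]})
    arrow_table = pa.Table.from_pandas(df)

    # One data kind per column (index columns included), later used to parse edits.
    dataframe_schema = determine_dataframe_schema(df, arrow_table.schema)
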
200 changes: 104 additions & 96 deletions lib/streamlit/elements/data_editor.py
@@ -14,7 +14,6 @@

from __future__ import annotations

import contextlib
import json
from dataclasses import dataclass
from typing import (
@@ -34,11 +33,19 @@

import pandas as pd
import pyarrow as pa
from pandas.api.types import is_datetime64_any_dtype, is_float_dtype, is_integer_dtype
from typing_extensions import Final, Literal, TypeAlias, TypedDict
from typing_extensions import Literal, TypeAlias, TypedDict

from streamlit import logger as _logger
from streamlit import type_util
from streamlit.elements.form import current_form_id
from streamlit.elements.lib.column_config_utils import (
INDEX_IDENTIFIER,
ColumnConfigMapping,
ColumnDataKind,
DataframeSchema,
determine_dataframe_schema,
marshall_column_config,
)
from streamlit.elements.lib.pandas_styler_utils import marshall_styler
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Arrow_pb2 import Arrow as ArrowProto
@@ -58,7 +65,7 @@

from streamlit.delta_generator import DeltaGenerator

_INDEX_IDENTIFIER: Final = "index"
_LOGGER = _logger.get_logger("root")

# All formats that support direct editing, meaning that these
# formats will be returned with the same type when used with data_editor.
@@ -91,25 +98,6 @@
]


class ColumnConfig(TypedDict, total=False):
title: Optional[str]
width: Optional[Literal["small", "medium", "large"]]
hidden: Optional[bool]
disabled: Optional[bool]
required: Optional[bool]
alignment: Optional[Literal["left", "center", "right"]]
type: Optional[
Literal[
"text",
"number",
"checkbox",
"selectbox",
"list",
]
]
type_options: Optional[Dict[str, Any]]


class EditingState(TypedDict, total=False):
"""
A dictionary representing the current state of the data editor.
@@ -133,44 +121,6 @@ class EditingState(TypedDict, total=False):
deleted_rows: List[int]
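
# Illustrative shape of the editing state sent from the frontend (not part of
# this change; based on the docstrings in this file). Cell keys use the
# "row:column" format, and column positions count index columns first:
#
#     {
#         "edited_cells": {"0:1": "new value", "2:2": "3.14"},
#         "added_rows": [{"1": "foo", "2": "42"}],
#         "deleted_rows": [0, 3],
#     }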


# A mapping of column names/IDs to column configs.
ColumnConfigMapping: TypeAlias = Dict[Union[int, str], ColumnConfig]


def _marshall_column_config(
proto: ArrowProto, columns: Optional[Dict[Union[int, str], ColumnConfig]] = None
) -> None:
"""Marshall the column config into the proto.
Parameters
----------
proto : ArrowProto
The proto to marshall into.
columns : Optional[ColumnConfigMapping]
The column config to marshall.
"""
if columns is None:
columns = {}

# Ignore all None values and prefix columns specified by index
def remove_none_values(input_dict: Dict[Any, Any]) -> Dict[Any, Any]:
new_dict = {}
for key, val in input_dict.items():
if isinstance(val, dict):
val = remove_none_values(val)
if val is not None:
new_dict[key] = val
return new_dict

proto.columns = json.dumps(
{
(f"col:{str(k)}" if isinstance(k, int) else k): v
for (k, v) in remove_none_values(columns).items()
}
)


@dataclass
class DataEditorSerde:
"""DataEditorSerde is used to serialize and deserialize the data editor state."""
@@ -190,16 +140,20 @@ def serialize(self, editing_state: EditingState) -> str:
return json.dumps(editing_state, default=str)


def _parse_value(value: Union[str, int, float, bool, None], dtype) -> Any:
def _parse_value(
value: str | int | float | bool | None,
column_data_kind: ColumnDataKind,
) -> Any:
"""Convert a value to the correct type.
Parameters
----------
value : str | int | float | bool | None
The value to convert.
dtype
The type of the value.
column_data_kind : ColumnDataKind
The determined data kind of the column. The column data kind refers to the
shared data type of the values in the column (e.g. integer, float, string).
Returns
-------
@@ -208,23 +162,53 @@ def _parse_value(value: Union[str, int, float, bool, None], dtype) -> Any:
if value is None:
return None

# TODO(lukasmasuch): how to deal with date & time columns?
try:
if column_data_kind == ColumnDataKind.STRING:
return str(value)

# Datetime values try to parse the value to datetime:
# The value is expected to be a ISO 8601 string
if is_datetime64_any_dtype(dtype):
return pd.to_datetime(value, errors="ignore")
elif is_integer_dtype(dtype):
with contextlib.suppress(ValueError):
if column_data_kind == ColumnDataKind.INTEGER:
return int(value)
elif is_float_dtype(dtype):
with contextlib.suppress(ValueError):

if column_data_kind == ColumnDataKind.FLOAT:
return float(value)

if column_data_kind == ColumnDataKind.BOOLEAN:
return bool(value)

if column_data_kind in [
ColumnDataKind.DATETIME,
ColumnDataKind.DATE,
ColumnDataKind.TIME,
]:
datetime_value = pd.to_datetime(value, utc=False)

if datetime_value is pd.NaT:
return None

if isinstance(datetime_value, pd.Timestamp):
datetime_value = datetime_value.to_pydatetime()

if column_data_kind == ColumnDataKind.DATETIME:
return datetime_value

if column_data_kind == ColumnDataKind.DATE:
return datetime_value.date()

if column_data_kind == ColumnDataKind.TIME:
return datetime_value.time()

except (ValueError, pd.errors.ParserError) as ex:
_LOGGER.warning(
"Failed to parse value %s as %s. Exception: %s", value, column_data_kind, ex
)
return None
return value
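
# Illustrative behavior of the parsing above (not part of this change):
#
#     _parse_value("3", ColumnDataKind.INTEGER)           -> 3
#     _parse_value("1.5", ColumnDataKind.FLOAT)            -> 1.5
#     _parse_value("2023-04-05", ColumnDataKind.DATE)      -> datetime.date(2023, 4, 5)
#     _parse_value("not-a-date", ColumnDataKind.DATETIME)  -> None, with a warning logged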


def _apply_cell_edits(
df: pd.DataFrame, edited_cells: Mapping[str, str | int | float | bool | None]
df: pd.DataFrame,
edited_cells: Mapping[str, str | int | float | bool | None],
dataframe_schema: DataframeSchema,
) -> None:
"""Apply cell edits to the provided dataframe (inplace).
@@ -237,6 +221,8 @@ def _apply_cell_edits(
A dictionary of cell edits. The keys are the cell ids in the format
"row:column" and the values are the new cell values.
dataframe_schema: DataframeSchema
The schema of the dataframe.
"""
index_count = df.index.nlevels or 0

@@ -247,17 +233,21 @@ def _apply_cell_edits(
# The edited cell is part of the index
# To support multi-index in the future: use a tuple of values here
# instead of a single value
df.index.values[row_pos] = _parse_value(value, df.index.dtype)
df.index.values[row_pos] = _parse_value(value, dataframe_schema[col_pos])
else:
# We need to subtract the number of index levels from col_pos
# to get the correct column position for Pandas DataFrames
mapped_column = col_pos - index_count
df.iat[row_pos, mapped_column] = _parse_value(
value, df.iloc[:, mapped_column].dtype
value, dataframe_schema[col_pos]
)
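
# Illustrative example (not part of this change): for a dataframe with one
# index level, an edit payload of {"4:0": "10", "4:1": "2023-04-05"} targets
# row position 4; column position 0 is the index, while column position 1 maps
# to dataframe column 0 after subtracting the index count. Each value is parsed
# with _parse_value using dataframe_schema[col_pos].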


def _apply_row_additions(df: pd.DataFrame, added_rows: List[Dict[str, Any]]) -> None:
def _apply_row_additions(
df: pd.DataFrame,
added_rows: List[Dict[str, Any]],
dataframe_schema: DataframeSchema,
) -> None:
"""Apply row additions to the provided dataframe (inplace).
Parameters
@@ -268,6 +258,9 @@ def _apply_row_additions(df: pd.DataFrame, added_rows: List[Dict[str, Any]]) ->
added_rows : List[Dict[str, Any]]
A list of row additions. Each row addition is a dictionary with the
column position as key and the new cell value as value.
dataframe_schema: DataframeSchema
The schema of the dataframe.
"""
if not added_rows:
return
@@ -279,7 +272,7 @@ def _apply_row_additions(df: pd.DataFrame, added_rows: List[Dict[str, Any]]) ->
# combination with loc. As a workaround, we manually track the values here:
range_index_stop = None
range_index_step = None
if type(df.index) == pd.RangeIndex:
if isinstance(df.index, pd.RangeIndex):
range_index_stop = df.index.stop
range_index_step = df.index.step

@@ -292,14 +285,12 @@ def _apply_row_additions(df: pd.DataFrame, added_rows: List[Dict[str, Any]]) ->
if col_pos < index_count:
# To support multi-index in the future: use a tuple of values here
# instead of a single value
index_value = _parse_value(value, df.index.dtype)
index_value = _parse_value(value, dataframe_schema[col_pos])
else:
# We need to subtract the number of index levels from the col_pos
# to get the correct column position for Pandas DataFrames
mapped_column = col_pos - index_count
new_row[mapped_column] = _parse_value(
value, df.iloc[:, mapped_column].dtype
)
new_row[mapped_column] = _parse_value(value, dataframe_schema[col_pos])
# Append the new row to the dataframe
if range_index_stop is not None:
df.loc[range_index_stop, :] = new_row
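
# Illustrative example (not part of this change): an added row arrives as a
# dict keyed by column position, e.g. {"1": "foo", "2": "42"}. Each value is
# parsed with _parse_value against dataframe_schema, and for a RangeIndex the
# new row is appended at the tracked range_index_stop label so it receives the
# next sequential index value.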
@@ -329,7 +320,11 @@ def _apply_row_deletions(df: pd.DataFrame, deleted_rows: List[int]) -> None:
df.drop(df.index[deleted_rows], inplace=True)


def _apply_dataframe_edits(df: pd.DataFrame, data_editor_state: EditingState) -> None:
def _apply_dataframe_edits(
df: pd.DataFrame,
data_editor_state: EditingState,
dataframe_schema: DataframeSchema,
) -> None:
"""Apply edits to the provided dataframe (inplace).
This includes cell edits, row additions and row deletions.
@@ -341,12 +336,15 @@ def _apply_dataframe_edits(df: pd.DataFrame, data_editor_state: EditingState) ->
data_editor_state : EditingState
The editing state of the data editor component.
dataframe_schema: DataframeSchema
The schema of the dataframe.
"""
if data_editor_state.get("edited_cells"):
_apply_cell_edits(df, data_editor_state["edited_cells"])
_apply_cell_edits(df, data_editor_state["edited_cells"], dataframe_schema)

if data_editor_state.get("added_rows"):
_apply_row_additions(df, data_editor_state["added_rows"])
_apply_row_additions(df, data_editor_state["added_rows"], dataframe_schema)

if data_editor_state.get("deleted_rows"):
_apply_row_deletions(df, data_editor_state["deleted_rows"])
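
# Illustrative end-to-end usage (not part of this change), mirroring what
# experimental_data_editor does further below:
#
#     arrow_table = pa.Table.from_pandas(data_df)
#     dataframe_schema = determine_dataframe_schema(data_df, arrow_table.schema)
#     _apply_dataframe_edits(data_df, widget_state.value, dataframe_schema)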
@@ -393,9 +391,9 @@ def _apply_data_specific_configs(
DataFormat.LIST_OF_ROWS,
DataFormat.COLUMN_VALUE_MAPPING,
]:
if _INDEX_IDENTIFIER not in columns_config:
columns_config[_INDEX_IDENTIFIER] = {}
columns_config[_INDEX_IDENTIFIER]["hidden"] = True
if INDEX_IDENTIFIER not in columns_config:
columns_config[INDEX_IDENTIFIER] = {}
columns_config[INDEX_IDENTIFIER]["hidden"] = True

# Rename the first column to "value" for some of the data formats
if data_format in [
@@ -593,13 +591,24 @@ def experimental_data_editor(

# Temporary workaround: We hide range indices if num_rows is dynamic.
# since the current way of handling this index during editing is a bit confusing.
if type(data_df.index) is pd.RangeIndex and num_rows == "dynamic":
if _INDEX_IDENTIFIER not in columns_config:
columns_config[_INDEX_IDENTIFIER] = {}
columns_config[_INDEX_IDENTIFIER]["hidden"] = True
if isinstance(data_df.index, pd.RangeIndex) and num_rows == "dynamic":
if INDEX_IDENTIFIER not in columns_config:
columns_config[INDEX_IDENTIFIER] = {}
columns_config[INDEX_IDENTIFIER]["hidden"] = True

# Convert the dataframe to an arrow table which is used as the main
# serialization format for sending the data to the frontend.
# We also utilize the arrow schema to determine the data kinds of every column.
arrow_table = pa.Table.from_pandas(data_df)

# Determine the dataframe schema which is required for parsing edited values
# and for checking type compatibilities.
dataframe_schema = determine_dataframe_schema(data_df, arrow_table.schema)

proto = ArrowProto()

proto.use_container_width = use_container_width

if width:
proto.width = width
if height:
@@ -619,10 +628,9 @@
default_uuid = str(hash(delta_path))
marshall_styler(proto, data, default_uuid)

table = pa.Table.from_pandas(data_df)
proto.data = type_util.pyarrow_table_to_bytes(table)
proto.data = type_util.pyarrow_table_to_bytes(arrow_table)

_marshall_column_config(proto, columns_config)
marshall_column_config(proto, columns_config)

serde = DataEditorSerde()

@@ -638,7 +646,7 @@
ctx=get_script_run_ctx(),
)

_apply_dataframe_edits(data_df, widget_state.value)
_apply_dataframe_edits(data_df, widget_state.value, dataframe_schema)
self.dg._enqueue("arrow_data_frame", proto)
return type_util.convert_df_to_data_format(data_df, data_format)
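
For context, a minimal app-level sketch of the behavior this commit improves (the exact parsed types depend on each column's data kind and are assumed here):

    import pandas as pd
    import streamlit as st

    df = pd.DataFrame({"when": pd.to_datetime(["2023-01-01"]), "count": [1]})
    edited = st.experimental_data_editor(df, num_rows="dynamic")

    # Edits made to "when" in the UI are parsed back into datetime values and
    # edits to "count" back into ints, based on each column's data kind rather
    # than its pandas dtype.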

Expand Down
