Snowpark integration: Add support for uncollected dataframes in st.map and improve st.table data collection #5590

Merged · 8 commits · Oct 25, 2022
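In short, this PR lets uncollected (unevaluated) Snowpark DataFrames and Tables be passed directly to st.map, and caps st.table at 100 collected rows. A minimal usage sketch of the intended behavior, assuming a configured snowflake.snowpark Session and a hypothetical LOCATIONS table with lat/lon columns (connection parameters are placeholders):

```python
import streamlit as st
from snowflake.snowpark import Session

# Hypothetical connection parameters -- substitute real credentials.
session = Session.builder.configs(
    {"account": "...", "user": "...", "password": "..."}
).create()

# An uncollected Snowpark Table; no .collect() or .to_pandas() call is needed.
locations = session.table("LOCATIONS")  # assumed to expose "lat" and "lon" columns

# st.map evaluates at most 10k rows and shows a caption when the data was capped.
st.map(locations)

# st.table evaluates only 100 rows, since large static tables render slowly.
st.table(locations)
```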
2 changes: 2 additions & 0 deletions e2e/scripts/st_arrow_unevaluated_snowpark_dataframe.py
@@ -24,3 +24,5 @@
st.bar_chart(snowpark_dataframe)

st.area_chart(snowpark_dataframe)

st.table(snowpark_dataframe)
10 changes: 10 additions & 0 deletions e2e/scripts/st_map.py
@@ -18,11 +18,21 @@
import pandas as pd

import streamlit as st
from tests.streamlit.snowpark_mocks import DataFrame as MockedSnowparkDataFrame
from tests.streamlit.snowpark_mocks import Table as MockedSnowparkTable

# Empty map.

st.map()

# st.map with unevaluated Snowpark Table

st.map(MockedSnowparkTable(is_map=True, num_of_rows=50000))

# st.map with unevaluated Snowpark DataFrame

st.map(MockedSnowparkDataFrame(is_map=True, num_of_rows=50000))

# Simple map.

# Cast is needed due to mypy not understanding the outcome of dividing
9 changes: 8 additions & 1 deletion e2e/specs/st_arrow_unevaluated_snowpark_dataframe.js
@@ -38,9 +38,14 @@ describe("st.DataFrame with unevaluated snowflake.snowpark.dataframe.DataFrame",
.should("have.length", 1)
});

it("table exists and is evaluated", () => {
cy.get("div [data-testid='stTable']")
.should("have.length", 1)
});

it("warning about data being capped exists", () => {
cy.get("div [data-testid='stCaptionContainer']")
.should("have.length", 4)
.should("have.length", 5)
});

it("warning about data being capped exists", () => {
@@ -52,6 +57,8 @@ describe("st.DataFrame with unevaluated snowflake.snowpark.dataframe.DataFrame",
.should("contain", "⚠️ Showing only 10k rows. Call collect() on the dataframe to show more.")
cy.getIndexed("div [data-testid='stCaptionContainer']", 3)
.should("contain", "⚠️ Showing only 10k rows. Call collect() on the dataframe to show more.")
cy.getIndexed("div [data-testid='stCaptionContainer']", 4)
.should("contain", "⚠️ Showing only 100 rows. Call collect() on the dataframe to show more.")
});

it("displays a line chart", () => {
21 changes: 17 additions & 4 deletions e2e/specs/st_map.spec.js
@@ -19,16 +19,29 @@ describe("st.map", () => {
cy.loadApp("http://localhost:3000/");
});

it("displays 3 maps", () => {
cy.get(".element-container .stDeckGlJsonChart").should("have.length", 3)
it("displays 5 maps", () => {
cy.get(".element-container .stDeckGlJsonChart").should("have.length", 5)
});

it("displays 3 zoom buttons", () => {
cy.get(".element-container .zoomButton").should("have.length", 3)
it("displays 5 zoom buttons", () => {
cy.get(".element-container .zoomButton").should("have.length", 5)
})

it("warning about data being capped exists", () => {
cy.get("div [data-testid='stCaptionContainer']")
.should("have.length", 2)
})

it("warning about data being capped has proper message value", () => {
cy.getIndexed("div [data-testid='stCaptionContainer']", 0)
.should("contain", "⚠️ Showing only 10k rows. Call collect() on the dataframe to show more.")
cy.getIndexed("div [data-testid='stCaptionContainer']", 1)
.should("contain", "⚠️ Showing only 10k rows. Call collect() on the dataframe to show more.")
})

it("displays the correct snapshot", () => {
cy.get(".mapboxgl-canvas")
cy.get(".element-container", { waitForAnimations: true }).last().matchThemedSnapshots("stDeckGlJsonChart")
})

});
5 changes: 5 additions & 0 deletions lib/streamlit/elements/arrow.py
@@ -132,6 +132,11 @@ def _arrow_table(self, data: Data = None) -> "DeltaGenerator":

"""

# If the data is an uncollected Snowpark object, collect it with a cap of 100 rows
# (instead of the 10k-row cap used everywhere else). st.table uses the lower cap because
# large static tables render slowly, take up too much screen space, and can crash the app.
if type_util.is_snowpark_data_object(data):
data = type_util.convert_anything_to_df(data, max_unevaluated_rows=100)

# If pandas.Styler uuid is not provided, a hash of the position
# of the element will be used. This will cause a rerender of the table
# when the position of the element is changed.
10 changes: 5 additions & 5 deletions lib/streamlit/elements/dataframe_selector.py
@@ -47,7 +47,7 @@ def dataframe(

Parameters
----------
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, snowflake.snowpark.dataframe.DataFrame, Iterable, dict, or None
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, snowflake.snowpark.dataframe.DataFrame, snowflake.snowpark.table.Table, Iterable, dict, or None
The data to display.

If 'data' is a pandas.Styler, it will be used to style its
@@ -117,7 +117,7 @@ def table(self, data: "Data" = None) -> "DeltaGenerator":

Parameters
----------
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, snowflake.snowpark.dataframe.DataFrame, Iterable, dict, or None
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, snowflake.snowpark.dataframe.DataFrame, snowflake.snowpark.table.Table, Iterable, dict, or None
The table data.
Pyarrow tables are not supported by Streamlit's legacy DataFrame serialization
(i.e. with `config.dataFrameSerialization = "legacy"`).
@@ -165,7 +165,7 @@ def line_chart(

Parameters
----------
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, snowflake.snowpark.dataframe.DataFrame, Iterable, dict or None
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, snowflake.snowpark.dataframe.DataFrame, snowflake.snowpark.table.Table, Iterable, dict or None
Data to be plotted.
Pyarrow tables are not supported by Streamlit's legacy DataFrame serialization
(i.e. with `config.dataFrameSerialization = "legacy"`).
@@ -248,7 +248,7 @@ def area_chart(

Parameters
----------
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, snowflake.snowpark.dataframe.DataFrame, Iterable, or dict
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, snowflake.snowpark.dataframe.DataFrame, snowflake.snowpark.table.Table, Iterable, or dict
Data to be plotted.
Pyarrow tables are not supported by Streamlit's legacy DataFrame serialization
(i.e. with `config.dataFrameSerialization = "legacy"`).
@@ -331,7 +331,7 @@ def bar_chart(

Parameters
----------
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, snowflake.snowpark.dataframe.DataFrame, Iterable, or dict
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, snowflake.snowpark.dataframe.DataFrame, snowflake.snowpark.table.Table, Iterable, or dict
Data to be plotted.
Pyarrow tables are not supported by Streamlit's legacy DataFrame serialization
(i.e. with `config.dataFrameSerialization = "legacy"`).
19 changes: 9 additions & 10 deletions lib/streamlit/elements/map.py
@@ -22,6 +22,7 @@
from typing_extensions import Final, TypeAlias

import streamlit.elements.deck_gl_json_chart as deck_gl_json_chart
from streamlit import type_util
from streamlit.errors import StreamlitAPIException
from streamlit.proto.DeckGlJsonChart_pb2 import DeckGlJsonChart as DeckGlJsonChartProto
from streamlit.runtime.metrics_util import gather_metrics
@@ -95,7 +96,7 @@ def map(

Parameters
----------
data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict,
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, snowflake.snowpark.dataframe.DataFrame, snowflake.snowpark.table.Table, Iterable, dict,
or None
The data to be plotted. Must have columns called 'lat', 'lon',
'latitude', or 'longitude'.
@@ -157,13 +158,17 @@ def _get_zoom_level(distance: float) -> int:


def to_deckgl_json(data: Data, zoom: Optional[int]) -> str:
if data is None:
return json.dumps(_DEFAULT_MAP)

# TODO(harahu): The ignore statement here is because iterables don't have
# the empty attribute. This is either a bug, or the documented data type
# is too broad. One or the other should be addressed, and the ignore
# statement removed.
if data is None or data.empty: # type: ignore[union-attr]
if hasattr(data, "empty") and data.empty: # type: ignore
return json.dumps(_DEFAULT_MAP)

data = type_util.convert_anything_to_df(data)

if "lat" in data:
lat = "lat"
elif "latitude" in data:
@@ -182,15 +187,9 @@ def to_deckgl_json(data: Data, zoom: Optional[int]) -> str:
'Map data must contain a column called "longitude" or "lon".'
)

# TODO(harahu): The ignore statement here is because iterables don't have
# the empty attribute. This is either a bug, or the documented data type
# is too broad. One or the other should be addressed, and the ignore
# statement removed.
if data[lon].isnull().values.any() or data[lat].isnull().values.any(): # type: ignore[index]
if data[lon].isnull().values.any() or data[lat].isnull().values.any():
Collaborator: nit: If we can remove the ignore statement here, we can also remove the comment from harahu.

Contributor Author: It's removed 👍
raise StreamlitAPIException("Latitude and longitude data must be numeric.")

data = pd.DataFrame(data)

min_lat = data[lat].min()
max_lat = data[lat].max()
min_lon = data[lon].min()
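For reference, the reordered guards in to_deckgl_json can be summarised as below; this is a condensed sketch rather than the full function, with _DEFAULT_MAP and type_util taken from the surrounding module:

```python
import json
from typing import Any, Optional

from streamlit import type_util
from streamlit.elements.map import _DEFAULT_MAP


def to_deckgl_json_guards(data: Any) -> Optional[str]:
    """Condensed ordering of the early-exit checks after this change."""
    # None short-circuits before any attribute access.
    if data is None:
        return json.dumps(_DEFAULT_MAP)

    # Only pandas-like objects expose .empty, so guard with hasattr instead of
    # relying solely on a type-ignore for iterables.
    if hasattr(data, "empty") and data.empty:
        return json.dumps(_DEFAULT_MAP)

    # Snowpark DataFrames/Tables are evaluated here, capped at 10k rows.
    data = type_util.convert_anything_to_df(data)

    # Column lookups, zoom computation, and layer construction follow in the real code.
    return None
```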
2 changes: 1 addition & 1 deletion lib/streamlit/elements/write.py
@@ -181,7 +181,7 @@ def flush_buffer():
# Order matters!
if isinstance(arg, str):
string_buffer.append(arg)
elif type_util.is_snowpark_dataframe(arg):
elif type_util.is_snowpark_data_object(arg):
flush_buffer()
self.dg.dataframe(arg)
elif type_util.is_dataframe_like(arg):
13 changes: 13 additions & 0 deletions lib/streamlit/string_util.py
@@ -144,3 +144,16 @@ def generate_download_filename_from_title(title_string: str) -> str:
file_name_string = clean_filename(title_string)
title_string = snake_case_to_camel_case(file_name_string)
return append_date_time_to_string(title_string)


def simplify_number(num: int) -> str:
"""Simplifies number into Human readable format, returns str"""
num_converted = float("{:.2g}".format(num))
magnitude = 0
while abs(num_converted) >= 1000:
magnitude += 1
num_converted /= 1000.0
return "{}{}".format(
"{:f}".format(num_converted).rstrip("0").rstrip("."),
["", "k", "m", "b", "t"][magnitude],
)
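A quick illustration of what the helper produces (these values follow from the implementation above):

```python
from streamlit.string_util import simplify_number

print(simplify_number(100))        # -> "100"
print(simplify_number(10_000))     # -> "10k"
print(simplify_number(2_500_000))  # -> "2.5m"
```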
29 changes: 20 additions & 9 deletions lib/streamlit/type_util.py
@@ -40,6 +40,7 @@
import streamlit as st
from streamlit import errors
from streamlit import logger as _logger
from streamlit import string_util

if TYPE_CHECKING:
import graphviz
@@ -204,6 +205,7 @@ def get_fqn_type(obj: object) -> str:
_NUMPY_ARRAY_TYPE_STR: Final = "numpy.ndarray"
_SNOWPARK_DF_TYPE_STR: Final = "snowflake.snowpark.dataframe.DataFrame"
_SNOWPARK_DF_ROW_TYPE_STR: Final = "snowflake.snowpark.row.Row"
_SNOWPARK_TABLE_TYPE_STR: Final = "snowflake.snowpark.table.Table"

_DATAFRAME_LIKE_TYPES: Final[tuple[str, ...]] = (
_PANDAS_DF_TYPE_STR,
@@ -240,10 +242,12 @@ def is_dataframe_like(obj: object) -> TypeGuard[DataFrameLike]:
return any(is_type(obj, t) for t in _DATAFRAME_LIKE_TYPES)


def is_snowpark_dataframe(obj: object) -> bool:
"""True if obj is of type snowflake.snowpark.dataframe.DataFrame or
def is_snowpark_data_object(obj: object) -> bool:
"""True if obj is of type snowflake.snowpark.dataframe.DataFrame, snowflake.snowpark.table.Table or
True when obj is a list which contains snowflake.snowpark.row.Row,
False otherwise"""
if is_type(obj, _SNOWPARK_TABLE_TYPE_STR):
return True
if is_type(obj, _SNOWPARK_DF_TYPE_STR):
return True
if not isinstance(obj, list):
@@ -416,13 +420,19 @@ def is_sequence(seq: Any) -> bool:
return True


def convert_anything_to_df(df: Any) -> DataFrame:
def convert_anything_to_df(
df: Any, max_unevaluated_rows: int = MAX_UNEVALUATED_DF_ROWS
) -> DataFrame:
"""Try to convert different formats to a Pandas Dataframe.

Parameters
----------
df : ndarray, Iterable, dict, DataFrame, Styler, pa.Table, None, dict, list, or any

max_unevaluated_rows: int
If unevaluated (Snowpark) data is detected, it is evaluated here, collecting at most
max_unevaluated_rows rows. Defaults to 10k; st.table passes 100.

Returns
-------
pandas.DataFrame
@@ -445,11 +455,12 @@ def convert_anything_to_df(df: Any) -> DataFrame:
if is_type(df, "numpy.ndarray") and len(df.shape) == 0:
return pd.DataFrame([])

if is_type(df, _SNOWPARK_DF_TYPE_STR) and not isinstance(df, list):
df = pd.DataFrame(df.take(MAX_UNEVALUATED_DF_ROWS))
if df.shape[0] == MAX_UNEVALUATED_DF_ROWS:
if is_type(df, _SNOWPARK_DF_TYPE_STR) or is_type(df, _SNOWPARK_TABLE_TYPE_STR):
df = pd.DataFrame(df.take(max_unevaluated_rows))
if df.shape[0] == max_unevaluated_rows:
st.caption(
f"⚠️ Showing only 10k rows. Call `collect()` on the dataframe to show more."
f"⚠️ Showing only {string_util.simplify_number(max_unevaluated_rows)} rows. "
"Call `collect()` on the dataframe to show more."
)
return df

@@ -492,14 +503,14 @@ def ensure_iterable(obj: Union[DataFrame, Iterable[V_co]]) -> Iterable[Any]:

Parameters
----------
obj : list, tuple, numpy.ndarray, pandas.Series, pandas.DataFrame or snowflake.snowpark.dataframe.DataFrame
obj : list, tuple, numpy.ndarray, pandas.Series, pandas.DataFrame, snowflake.snowpark.dataframe.DataFrame or snowflake.snowpark.table.Table

Returns
-------
iterable

"""
if is_snowpark_dataframe(obj):
if is_snowpark_data_object(obj):
obj = convert_anything_to_df(obj)

if is_dataframe(obj):
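Taken together, the new helpers behave roughly like this; a sketch using the mocked Snowpark objects from the test suite (the tests below rely on the same capping behavior), not a real Snowpark session:

```python
from streamlit import type_util
from tests.streamlit.snowpark_mocks import Table as MockSnowparkTable

# The mock stands in for uncollected Snowpark data, as in the tests below.
snowpark_table = MockSnowparkTable(is_map=True, num_of_rows=50_000)

# Both Snowpark DataFrames and Tables are now recognised.
assert type_util.is_snowpark_data_object(snowpark_table)

# Default cap: 10k rows, plus a "Showing only 10k rows" caption when truncated.
df_default = type_util.convert_anything_to_df(snowpark_table)
assert len(df_default) == 10_000

# st.table passes max_unevaluated_rows=100 to keep static tables small.
df_table = type_util.convert_anything_to_df(snowpark_table, max_unevaluated_rows=100)
assert len(df_table) == 100
```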
14 changes: 14 additions & 0 deletions lib/tests/streamlit/dataframe_selector_test.py
@@ -23,6 +23,7 @@
import streamlit
from streamlit.delta_generator import DeltaGenerator
from tests.streamlit.snowpark_mocks import DataFrame as MockSnowparkDataFrame
from tests.streamlit.snowpark_mocks import Table as MockSnowparkTable
from tests.testutil import patch_config_options

DATAFRAME = pd.DataFrame([["A", "B", "C", "D"], [28, 55, 43, 91]], index=["a", "b"]).T
@@ -65,6 +66,19 @@ def test_arrow_dataframe_with_snowpark_dataframe(
snowpark_df, 100, 200, use_container_width=False
)

@patch.object(DeltaGenerator, "_legacy_dataframe")
@patch.object(DeltaGenerator, "_arrow_dataframe")
@patch_config_options({"global.dataFrameSerialization": "arrow"})
def test_arrow_dataframe_with_snowpark_table(
self, arrow_dataframe, legacy_dataframe
):
snowpark_table = MockSnowparkTable()
streamlit.dataframe(snowpark_table, 100, 200)
legacy_dataframe.assert_not_called()
arrow_dataframe.assert_called_once_with(
snowpark_table, 100, 200, use_container_width=False
)

@patch.object(DeltaGenerator, "_legacy_table")
@patch.object(DeltaGenerator, "_arrow_table")
@patch_config_options({"global.dataFrameSerialization": "legacy"})
38 changes: 38 additions & 0 deletions lib/tests/streamlit/elements/map_test.py
@@ -22,6 +22,8 @@
import streamlit as st
from streamlit.elements.map import _DEFAULT_MAP, _DEFAULT_ZOOM_LEVEL
from tests.delta_generator_test_case import DeltaGeneratorTestCase
from tests.streamlit.snowpark_mocks import DataFrame as MockedSnowparkDataFrame
from tests.streamlit.snowpark_mocks import Table as MockedSnowparkTable

df1 = pd.DataFrame({"lat": [1, 2, 3, 4], "lon": [10, 20, 30, 40]})

@@ -93,3 +95,39 @@ def test_nan_exception(self):
st.map(df)

self.assertTrue("data must be numeric." in str(ctx.exception))

def test_unevaluated_snowpark_table(self):
"""Test st.map with unevaluated Snowpark Table"""
mocked_snowpark_table = MockedSnowparkTable(is_map=True, num_of_rows=50000)
st.map(mocked_snowpark_table)

c = json.loads(self.get_delta_from_queue().new_element.deck_gl_json_chart.json)

self.assertIsNotNone(c.get("initialViewState"))
self.assertIsNotNone(c.get("layers"))
self.assertIsNone(c.get("mapStyle"))
self.assertEqual(len(c.get("layers")), 1)
self.assertEqual(c.get("initialViewState").get("pitch"), 0)
self.assertEqual(c.get("layers")[0].get("@@type"), "ScatterplotLayer")

"""Check if map data was cut to 10k rows"""
self.assertEqual(len(c["layers"][0]["data"]), 10000)

def test_unevaluated_snowpark_dataframe(self):
"""Test st.map with unevaluated Snowpark DataFrame"""
mocked_snowpark_dataframe = MockedSnowparkDataFrame(
is_map=True, num_of_rows=50000
)
st.map(mocked_snowpark_dataframe)

c = json.loads(self.get_delta_from_queue().new_element.deck_gl_json_chart.json)

self.assertIsNotNone(c.get("initialViewState"))
self.assertIsNotNone(c.get("layers"))
self.assertIsNone(c.get("mapStyle"))
self.assertEqual(len(c.get("layers")), 1)
self.assertEqual(c.get("initialViewState").get("pitch"), 0)
self.assertEqual(c.get("layers")[0].get("@@type"), "ScatterplotLayer")

"""Check if map data was cut to 10k rows"""
self.assertEqual(len(c["layers"][0]["data"]), 10000)