Lazy-load pandas and pyarrow to improve performance (#8125)
## Describe your changes

Lazy-load `pandas` and `pyarrow` so they are only imported when required
(e.g. when `st.dataframe` is used).

This PR also includes a couple of other small refactorings related to
typing and imports.
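
As the diffs below show, the approach moves each module-level `import pandas as pd` / `import pyarrow as pa` into the function bodies that need it, and keeps a `TYPE_CHECKING`-only import (together with `from __future__ import annotations`) for the annotations. A minimal sketch of the pattern — the function name `to_dataframe` is illustrative, not taken from this PR:

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Only imported for type checking; never executed at runtime.
    from pandas import DataFrame


def to_dataframe(data: Any) -> DataFrame:
    # Deferred import: pandas is loaded the first time this function runs,
    # not when the module is imported.
    import pandas as pd

    return pd.DataFrame(data)
```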

## GitHub Issue Link (if applicable)

Related to #6066

## Testing Plan

- Added an e2e test to ensure that `pyarrow` and `pandas` are lazy-loaded.
---

**Contribution License Agreement**

By submitting this pull request you agree that all contributions to this
project are made under the Apache 2.0 license.
LukasMasuch committed Feb 9, 2024
1 parent 5a6c331 commit 1995def
Showing 15 changed files with 219 additions and 124 deletions.
5 changes: 2 additions & 3 deletions e2e_playwright/lazy_loaded_modules.py
@@ -24,19 +24,18 @@
"altair",
"graphviz",
"watchdog",
"pandas",
"pyarrow",
"streamlit.emojis",
"streamlit.external",
"streamlit.vendor.pympler",
"streamlit.watcher.event_based_path_watcher",
# TODO(lukasmasuch): Lazy load more packages:
# "streamlit.hello",
# "pandas",
# "pyarrow",
# "numpy",
# "matplotlib",
# "plotly",
# "pillow",
# "watchdog",
]

for module in lazy_loaded_modules:
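The loop body is cut off above; a plausible sketch of what the test app does with it, assuming it checks `sys.modules` and reports each module's state via `st.markdown` (only the "not loaded" wording is confirmed by the assertion in the test below — the rest is an assumption):

```python
import sys

import streamlit as st

for module in lazy_loaded_modules:
    # Hypothetical body: report whether the module was imported at startup.
    loaded = "loaded" if module in sys.modules else "not loaded"
    st.markdown(f"**{module}**: {loaded}")
```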
2 changes: 1 addition & 1 deletion e2e_playwright/lazy_loaded_modules_test.py
@@ -20,6 +20,6 @@
def test_lazy_loaded_modules_are_not_imported(app: Page):
"""Test that lazy loaded modules are not imported when the page is loaded."""
markdown_elements = app.get_by_test_id("stMarkdown")
expect(markdown_elements).to_have_count(11)
expect(markdown_elements).to_have_count(13)
for element in markdown_elements.all():
expect(element).to_have_text(re.compile(r".*not loaded.*"))
22 changes: 15 additions & 7 deletions lib/streamlit/components/v1/component_arrow.py
@@ -18,14 +18,15 @@

from __future__ import annotations

from typing import Any

import pandas as pd
from typing import TYPE_CHECKING, Any

from streamlit import type_util
from streamlit.elements.lib import pandas_styler_utils
from streamlit.proto.Components_pb2 import ArrowTable as ArrowTableProto

if TYPE_CHECKING:
from pandas import DataFrame, Index, Series


def marshall(
proto: ArrowTableProto, data: Any, default_uuid: str | None = None
@@ -50,7 +51,7 @@ def marshall(
_marshall_data(proto, df)


def _marshall_index(proto: ArrowTableProto, index: pd.Index) -> None:
def _marshall_index(proto: ArrowTableProto, index: Index) -> None:
"""Marshall pandas.DataFrame index into an ArrowTable proto.
Parameters
@@ -63,12 +64,14 @@ def _marshall_index(proto: ArrowTableProto, index: pd.Index) -> None:
Will default to RangeIndex (0, 1, 2, ..., n) if no index is provided.
"""
import pandas as pd

index = map(type_util.maybe_tuple_to_list, index.values)
index_df = pd.DataFrame(index)
proto.index = type_util.data_frame_to_bytes(index_df)


def _marshall_columns(proto: ArrowTableProto, columns: pd.Series) -> None:
def _marshall_columns(proto: ArrowTableProto, columns: Series) -> None:
"""Marshall pandas.DataFrame columns into an ArrowTable proto.
Parameters
@@ -81,12 +84,14 @@ def _marshall_columns(proto: ArrowTableProto, columns: pd.Series) -> None:
Will default to RangeIndex (0, 1, 2, ..., n) if no column labels are provided.
"""
import pandas as pd

columns = map(type_util.maybe_tuple_to_list, columns.values)
columns_df = pd.DataFrame(columns)
proto.columns = type_util.data_frame_to_bytes(columns_df)


def _marshall_data(proto: ArrowTableProto, df: pd.DataFrame) -> None:
def _marshall_data(proto: ArrowTableProto, df: DataFrame) -> None:
"""Marshall pandas.DataFrame data into an ArrowTable proto.
Parameters
@@ -101,7 +106,7 @@ def _marshall_data(proto: ArrowTableProto, df: pd.DataFrame) -> None:
proto.data = type_util.data_frame_to_bytes(df)


def arrow_proto_to_dataframe(proto: ArrowTableProto) -> pd.DataFrame:
def arrow_proto_to_dataframe(proto: ArrowTableProto) -> DataFrame:
"""Convert ArrowTable proto to pandas.DataFrame.
Parameters
@@ -110,12 +115,15 @@ def arrow_proto_to_dataframe(proto: ArrowTableProto) -> pd.DataFrame:
Output. pandas.DataFrame
"""

if type_util.is_pyarrow_version_less_than("14.0.1"):
raise RuntimeError(
"The installed pyarrow version is not compatible with this component. "
"Please upgrade to 14.0.1 or higher: pip install -U pyarrow"
)

import pandas as pd

data = type_util.bytes_to_data_frame(proto.data)
index = type_util.bytes_to_data_frame(proto.index)
columns = type_util.bytes_to_data_frame(proto.columns)
14 changes: 7 additions & 7 deletions lib/streamlit/connections/snowflake_connection.py
@@ -23,19 +23,19 @@
from datetime import timedelta
from typing import TYPE_CHECKING, cast

import pandas as pd

from streamlit.connections import BaseConnection
from streamlit.connections.util import running_in_sis
from streamlit.errors import StreamlitAPIException
from streamlit.runtime.caching import cache_data

if TYPE_CHECKING:
from pandas import DataFrame
from snowflake.connector.cursor import SnowflakeCursor # type:ignore[import]
from snowflake.snowpark.session import Session # type:ignore[import]

from snowflake.connector import ( # type:ignore[import] # isort: skip
SnowflakeConnection as InternalSnowflakeConnection,
)
from snowflake.connector.cursor import SnowflakeCursor # type:ignore[import]
from snowflake.snowpark.session import Session # type:ignore[import]


class SnowflakeConnection(BaseConnection["InternalSnowflakeConnection"]):
@@ -125,7 +125,7 @@ def query(
show_spinner: bool | str = "Running `snowflake.query(...)`.",
params=None,
**kwargs,
) -> pd.DataFrame:
) -> DataFrame:
"""Run a read-only SQL query.
This method implements both query result caching (with caching behavior
@@ -202,7 +202,7 @@ def query(
),
wait=wait_fixed(1),
)
def _query(sql: str) -> pd.DataFrame:
def _query(sql: str) -> DataFrame:
cur = self._instance.cursor()
cur.execute(sql, params=params, **kwargs)
return cur.fetch_pandas_all()
@@ -224,7 +224,7 @@ def _query(sql: str) -> pd.DataFrame:

def write_pandas(
self,
df: pd.DataFrame,
df: DataFrame,
table_name: str,
database: str | None = None,
schema: str | None = None,
13 changes: 7 additions & 6 deletions lib/streamlit/connections/snowpark_connection.py
@@ -18,13 +18,13 @@
# way to configure this at a per-line level :(
# mypy: no-warn-unused-ignores

from __future__ import annotations

import threading
from collections import ChainMap
from contextlib import contextmanager
from datetime import timedelta
from typing import TYPE_CHECKING, Iterator, Optional, Union, cast

import pandas as pd
from typing import TYPE_CHECKING, Iterator, cast

from streamlit.connections import BaseConnection
from streamlit.connections.util import (
@@ -36,6 +36,7 @@
from streamlit.runtime.caching import cache_data

if TYPE_CHECKING:
from pandas import DataFrame
from snowflake.snowpark.session import Session # type:ignore[import]


@@ -96,8 +97,8 @@ def _connect(self, **kwargs) -> "Session":
def query(
self,
sql: str,
ttl: Optional[Union[float, int, timedelta]] = None,
) -> pd.DataFrame:
ttl: float | int | timedelta | None = None,
) -> DataFrame:
"""Run a read-only SQL query.
This method implements both query result caching (with caching behavior
@@ -144,7 +145,7 @@ def query(
retry=retry_if_exception_type(SnowparkServerException),
wait=wait_fixed(1),
)
def _query(sql: str) -> pd.DataFrame:
def _query(sql: str) -> DataFrame:
with self._lock:
return self._instance.sql(sql).to_pandas()

18 changes: 10 additions & 8 deletions lib/streamlit/connections/sql_connection.py
@@ -11,21 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections import ChainMap
from copy import deepcopy
from datetime import timedelta
from typing import TYPE_CHECKING, List, Optional, Union, cast

import pandas as pd
from typing import TYPE_CHECKING, List, cast

from streamlit.connections import BaseConnection
from streamlit.connections.util import extract_from_dict
from streamlit.errors import StreamlitAPIException
from streamlit.runtime.caching import cache_data

if TYPE_CHECKING:
from pandas import DataFrame
from sqlalchemy.engine import Connection as SQLAlchemyConnection
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
@@ -124,12 +124,12 @@ def query(
sql: str,
*, # keyword-only arguments:
show_spinner: bool | str = "Running `sql.query(...)`.",
ttl: Optional[Union[float, int, timedelta]] = None,
index_col: Optional[Union[str, List[str]]] = None,
chunksize: Optional[int] = None,
ttl: float | int | timedelta | None = None,
index_col: str | List[str] | None = None,
chunksize: int | None = None,
params=None,
**kwargs,
) -> pd.DataFrame:
) -> DataFrame:
"""Run a read-only query.
This method implements both query result caching (with caching behavior
@@ -211,7 +211,9 @@ def _query(
chunksize=None,
params=None,
**kwargs,
) -> pd.DataFrame:
) -> DataFrame:
import pandas as pd

instance = self._instance.connect()
return pd.read_sql(
text(sql),
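The `Optional[Union[...]]` → `float | int | timedelta | None` rewrites above work on older Python versions because of the added `from __future__ import annotations`: with postponed evaluation (PEP 563), annotations are stored as strings and never executed, so the PEP 604 union syntax in annotations does not require Python 3.10. A minimal sketch, assuming nothing introspects the annotations at runtime — `cached_query` is an illustrative name, not from the PR:

```python
from __future__ import annotations

from datetime import timedelta


def cached_query(sql: str, ttl: float | int | timedelta | None = None) -> None:
    # The annotation above is never evaluated at runtime, so the `|` union
    # syntax is safe even on Python 3.8/3.9.
    print(sql, ttl)
```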
17 changes: 10 additions & 7 deletions lib/streamlit/elements/arrow.py
@@ -11,12 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Union, cast

import pyarrow as pa
from typing_extensions import TypeAlias

from streamlit import type_util
@@ -33,6 +32,7 @@
from streamlit.runtime.metrics_util import gather_metrics

if TYPE_CHECKING:
import pyarrow as pa
from numpy import ndarray
from pandas import DataFrame, Index, Series
from pandas.io.formats.style import Styler
@@ -44,7 +44,7 @@
"Series",
"Styler",
"Index",
pa.Table,
"pa.Table",
"ndarray",
Iterable,
Dict[str, List[Any]],
@@ -184,6 +184,7 @@ def dataframe(
height: 350px
"""
import pyarrow as pa

# Convert the user provided column config into the frontend compatible format:
column_config_mapping = process_config_mapping(column_config)
@@ -283,7 +284,7 @@ def table(self, data: Data = None) -> "DeltaGenerator":
return self.dg._enqueue("arrow_table", proto)

@gather_metrics("add_rows")
def add_rows(self, data: "Data" = None, **kwargs) -> Optional["DeltaGenerator"]:
def add_rows(self, data: "Data" = None, **kwargs) -> "DeltaGenerator" | None:
"""Concatenate a dataframe to the bottom of the current one.
Parameters
@@ -342,7 +343,7 @@ def dg(self) -> "DeltaGenerator":
return cast("DeltaGenerator", self)


def marshall(proto: ArrowProto, data: Data, default_uuid: Optional[str] = None) -> None:
def marshall(proto: ArrowProto, data: Data, default_uuid: str | None = None) -> None:
"""Marshall pandas.DataFrame into an Arrow proto.
Parameters
@@ -353,12 +354,14 @@ def marshall(proto: ArrowProto, data: Data, default_uuid: Optional[str] = None)
data : pandas.DataFrame, pandas.Styler, pyarrow.Table, numpy.ndarray, pyspark.sql.DataFrame, snowflake.snowpark.DataFrame, Iterable, dict, or None
Something that is or can be converted to a dataframe.
default_uuid : Optional[str]
default_uuid : str | None
If pandas.Styler UUID is not provided, this value will be used.
This attribute is optional and only used for pandas.Styler, other elements
(e.g. charts) can ignore it.
"""
import pyarrow as pa

if type_util.is_pandas_styler(data):
# default_uuid is a string only if the data is a `Styler`,
# and `None` otherwise.
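Note the `pa.Table` → `"pa.Table"` change in the `Data` alias above: once `pyarrow` is imported only under `TYPE_CHECKING`, the name does not exist at runtime, so it has to appear as a string forward reference inside the `Union`. A simplified sketch of that pattern (not the full alias from the PR):

```python
from typing import TYPE_CHECKING, Union

from typing_extensions import TypeAlias

if TYPE_CHECKING:
    import pyarrow as pa
    from pandas import DataFrame

# String forward references are resolved only by type checkers, so pandas and
# pyarrow do not need to be imported when this module loads.
Data: TypeAlias = Union["DataFrame", "pa.Table", None]
```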
22 changes: 17 additions & 5 deletions lib/streamlit/elements/arrow_altair.py
@@ -21,11 +21,17 @@
from contextlib import nullcontext
from datetime import date
from enum import Enum
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Sequence, Tuple, cast

import pandas as pd
from pandas.api.types import infer_dtype, is_integer_dtype
from typing_extensions import Literal
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
List,
Literal,
Sequence,
Tuple,
cast,
)

import streamlit.elements.arrow_vega_lite as arrow_vega_lite
from streamlit import type_util
Expand All @@ -47,6 +53,7 @@

if TYPE_CHECKING:
import altair as alt
import pandas as pd

from streamlit.delta_generator import DeltaGenerator

@@ -848,6 +855,8 @@ def _melt_data(
new_color_column_name: str,
) -> pd.DataFrame:
"""Converts a wide-format dataframe to a long-format dataframe."""
import pandas as pd
from pandas.api.types import infer_dtype

melted_df = pd.melt(
df,
@@ -1083,6 +1092,8 @@ def _convert_col_names_to_str_in_place(
size_column: str | None,
) -> Tuple[str | None, List[str], str | None, str | None]:
"""Converts column names to strings, since Vega-Lite does not accept ints, etc."""
import pandas as pd

column_names = list(df.columns) # list() converts RangeIndex, etc, to regular list.
str_column_names = [str(c) for c in column_names]
df.columns = pd.Index(str_column_names)
@@ -1186,6 +1197,7 @@ def _get_scale(df: pd.DataFrame, column_name: str | None) -> alt.Scale:

def _get_axis_config(df: pd.DataFrame, column_name: str | None, grid: bool) -> alt.Axis:
import altair as alt
from pandas.api.types import is_integer_dtype

if column_name is not None and is_integer_dtype(df[column_name]):
# Use a max tick size of 1 for integer columns (prevents zoom into float numbers)
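A rough way to observe the effect of this change from an app script — assuming nothing else in the script pulls the modules in earlier, and that the lazy loading behaves as the e2e test above asserts:

```python
import sys

import streamlit as st

# After this change, importing streamlit alone should not import pandas/pyarrow.
st.write("pandas preloaded:", "pandas" in sys.modules)
st.write("pyarrow preloaded:", "pyarrow" in sys.modules)

st.dataframe({"a": [1, 2, 3]})  # first dataframe call triggers the lazy imports

st.write("pandas loaded after st.dataframe:", "pandas" in sys.modules)
```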
