Skip to content

Commit

Permalink
fix: Clean up snowflake to_spark_df() (feast-dev#3607)
Browse files Browse the repository at this point in the history
Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
  • Loading branch information
sfc-gh-madkins committed Apr 24, 2023
1 parent 902f23f commit e8e643e
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 33 deletions.
6 changes: 3 additions & 3 deletions docs/reference/offline-stores/overview.md
Expand Up @@ -46,11 +46,11 @@ Below is a matrix indicating which `RetrievalJob`s support what functionality.
| --------------------------------- | --- | --- | --- | --- | --- | --- | --- |
| export to dataframe | yes | yes | yes | yes | yes | yes | yes |
| export to arrow table | yes | yes | yes | yes | yes | yes | yes |
| export to arrow batches | no | no | yes | yes | no | no | no |
| export to SQL | no | yes | yes | yes | yes | no | yes |
| export to arrow batches | no | no | no | yes | no | no | no |
| export to SQL | no | yes | yes | yes | yes | no | yes |
| export to data lake (S3, GCS, etc.) | no | no | yes | no | yes | no | no |
| export to data warehouse | no | yes | yes | yes | yes | no | no |
| export as Spark dataframe | no | no | yes | no | no | yes | no |
| export as Spark dataframe | no | no | yes | no | no | yes | no |
| local execution of Python-based on-demand transforms | yes | yes | yes | yes | yes | no | yes |
| remote execution of Python-based on-demand transforms | no | no | no | no | no | no | no |
| persist results in the offline store | yes | yes | yes | yes | yes | yes | no |
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/offline-stores/snowflake.md
Expand Up @@ -53,7 +53,7 @@ Below is a matrix indicating which functionality is supported by `SnowflakeRetri
| ----------------------------------------------------- | --------- |
| export to dataframe | yes |
| export to arrow table | yes |
| export to arrow batches | yes |
| export to arrow batches | yes |
| export to SQL | yes |
| export to data lake (S3, GCS, etc.) | yes |
| export to data warehouse | yes |
Expand Down
8 changes: 0 additions & 8 deletions sdk/python/feast/errors.py
Expand Up @@ -56,14 +56,6 @@ def __init__(self, name, project=None):
super().__init__(f"Feature view {name} does not exist")


class InvalidSparkSessionException(Exception):
def __init__(self, spark_arg):
super().__init__(
f" Need Spark Session to convert results to spark data frame\
recieved {type(spark_arg)} instead. "
)


class OnDemandFeatureViewNotFoundException(FeastObjectNotFoundException):
def __init__(self, name, project=None):
if project:
Expand Down
32 changes: 11 additions & 21 deletions sdk/python/feast/infra/offline_stores/snowflake.py
Expand Up @@ -28,11 +28,7 @@

from feast import OnDemandFeatureView
from feast.data_source import DataSource
from feast.errors import (
EntitySQLEmptyResults,
InvalidEntityType,
InvalidSparkSessionException,
)
from feast.errors import EntitySQLEmptyResults, InvalidEntityType
from feast.feature_logging import LoggingConfig, LoggingSource
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores import offline_utils
Expand Down Expand Up @@ -528,28 +524,22 @@ def to_spark_df(self, spark_session: "SparkSession") -> "DataFrame":
"""

try:
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import DataFrame
except ImportError as e:
from feast.errors import FeastExtrasDependencyImportError

raise FeastExtrasDependencyImportError("spark", str(e))

if isinstance(spark_session, SparkSession):
arrow_batches = self.to_arrow_batches()
spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

if arrow_batches:
spark_df = reduce(
DataFrame.unionAll,
[
spark_session.createDataFrame(batch.to_pandas())
for batch in arrow_batches
],
)
return spark_df
else:
raise EntitySQLEmptyResults(self.to_sql())
else:
raise InvalidSparkSessionException(spark_session)
# This can be improved by parallelizing the read of chunks
pandas_batches = self.to_pandas_batches()

spark_df = reduce(
DataFrame.unionAll,
[spark_session.createDataFrame(batch) for batch in pandas_batches],
)
return spark_df

def persist(
self,
Expand Down

0 comments on commit e8e643e

Please sign in to comment.