In [20]:
%%configure -f
{
    "conf": {
        "spark.pyspark.python": "python3",
        "spark.pyspark.virtualenv.enabled": "true",
        "spark.pyspark.virtualenv.type": "native",
        "spark.pyspark.virtualenv.bin.path": "/usr/bin/virtualenv",
        "spark.sql.catalogImplementation": "hive"
    }
}

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
19,application_1616797268065_0020,pyspark,idle,Link,Link,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
19,application_1616797268065_0020,pyspark,idle,Link,Link,,✔


In [21]:
# Import pyspark and build Spark session

import os

from pyspark.sql import SparkSession
from pyspark.sql.functions import datediff
from pyspark.sql.functions import lit
from pyspark.sql.functions import max as sql_max
from pyspark.sql.functions import min as sql_min
from pyspark.sql.types import StringType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType

# getOrCreate() returns the already-running session when one exists (the sparkmagic
# kernel starts one for us), so this is safe to re-run.
# NOTE(review): on an existing session the app name is presumably not re-applied — confirm.
spark = (
    SparkSession.builder
    .appName("PySparkApp")
    .getOrCreate()
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [22]:
# Location of the offline feature store in S3.
# BUCKET and PREFIX default to this notebook's feature group. They can be overridden with the
# FEATURE_STORE_BUCKET / FEATURE_STORE_PREFIX environment variables, so the notebook can target
# another group (or account) without editing code.
BUCKET = os.environ.get("FEATURE_STORE_BUCKET", "jmackay-offline-feature-store-test")
PREFIX = os.environ.get(
    "FEATURE_STORE_PREFIX",
    "sagemaker-featurestore/760493221347/sagemaker/us-west-2/offline-store/my-feature-group-077",
)
# The Parquet data read below lives under the "data" subdirectory of the prefix.
FEATURE_STORE_PATH = f"s3://{BUCKET}/{PREFIX}/data"

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [23]:
# Load every Parquet file under the offline-store path and preview the first rows.
feature_store_df = spark.read.format("parquet").load(FEATURE_STORE_PATH)
feature_store_df.show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---+--------------------+---------+--------------------+-------------------+----------+----+-----+---+----+
| id|          event_time|feature_1|          write_time|api_invocation_time|is_deleted|year|month|day|hour|
+---+--------------------+---------+--------------------+-------------------+----------+----+-----+---+----+
|  4|2021-03-03T20:51:...|      4.3|2021-03-19 22:12:...|2021-03-19 22:07:02|     false|2021|    3|  3|  20|
|  5|2021-03-03T20:51:...|      5.3|2021-03-19 22:12:...|2021-03-19 22:07:02|     false|2021|    3|  3|  20|
|  4|2021-03-02T20:51:...|      4.2|2021-03-19 22:12:...|2021-03-19 22:07:02|     false|2021|    3|  2|  20|
|  5|2021-03-02T20:51:...|      5.2|2021-03-19 22:12:...|2021-03-19 22:07:02|     false|2021|    3|  2|  20|
|  4|2021-03-01T20:51:...|      4.1|2021-03-19 22:12:...|2021-03-19 22:07:01|     false|2021|    3|  1|  20|
|  5|2021-03-01T20:51:...|      5.1|2021-03-19 22:12:...|2021-03-19 22:07:01|     false|2021|    3|  1|  20|
|  4|2021-03-05T20:

In [24]:
# Build the query set: each row asks for the latest features of `id` as of `query_time`.
# Both columns are non-nullable strings; query_time values below are ISO-8601 UTC strings.
sc = spark.sparkContext  # NOTE(review): not referenced by any other cell visible here
query_df_schema = StructType(
    [
        StructField("id", StringType(), False),
        StructField("query_time", StringType(), False),
    ]
)

query_df = spark.createDataFrame(
    [
        ("1", "2021-03-05T00:00:00Z"),
        ("4", "2021-03-04T00:00:00Z"),
        ("5", "2021-03-25T00:00:00Z"),
    ],
    query_df_schema,
)

# To instead load this query df from s3:
# QUERY_PATH = f"s3://{BUCKET}/{PREFIX}/test_query.parquet"
# query_df = spark.read.parquet(QUERY_PATH)

query_df.show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---+--------------------+
| id|          query_time|
+---+--------------------+
|  1|2021-03-05T00:00:00Z|
|  4|2021-03-04T00:00:00Z|
|  5|2021-03-25T00:00:00Z|
+---+--------------------+

In [25]:
# Earliest and latest query times, computed together in a single aggregation pass; they
# bound the coarse feature-store prefilter further down.
# query_time is a string column, so min/max are lexicographic. That matches chronological
# order only while every value shares the same ISO-8601 layout (as in query_df above).
minmax_time = query_df.agg(sql_min("query_time"), sql_max("query_time")).collect()
# The aggregate yields exactly one Row of (min, max); unpack it positionally rather than
# by Spark's auto-generated "min(query_time)" / "max(query_time)" column names.
min_time, max_time = minmax_time[0]

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [26]:
# Keep only live (not deleted) records. A row whose is_deleted is null evaluates to null
# under NOT, and Spark's filter drops such rows too.
filtered = feature_store_df.where(~feature_store_df["is_deleted"])

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [27]:
# Coarse prefilter shared by every query (performance only): drop records newer than the
# latest query time, and records whose event date falls more than `allowed_staleness_days`
# calendar days before the earliest query time. Shrinking the data here makes the
# per-(id, query_time) filtering after the join cheaper.
# NOTE: datediff compares calendar dates, so the staleness bound is day-granular, and the
# `<= max_time` test compares strings lexicographically.
allowed_staleness_days = 4
filtered = filtered.where(
    (datediff(lit(min_time), feature_store_df.event_time) <= allowed_staleness_days)
    & (feature_store_df.event_time <= max_time)
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [28]:
# Pair every surviving record with the query rows for the same id, then drop the query
# side's copy of `id` so the result carries a single id column.
# NOTE(review): query_df.id is a string; if the feature store's id column has another type,
# Spark coerces the comparison — confirm that matches intent.
joined = filtered.join(
    query_df,
    on=feature_store_df.id == query_df.id,
    how="inner",
).drop(query_df.id)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [29]:
# Per-query window: keep a record only if its event_time is at or before the row's
# query_time, and its event date is no more than `allowed_staleness_days` calendar days
# before the query date.
# NOTE(review): event_time and query_time appear to be ISO-8601 strings, so `<=` is a
# lexicographic comparison that is only chronological when both use the same layout
# (e.g. the same fractional-second precision). Consider casting to timestamps if they can differ.
drop_future_and_stale = joined.where(
    (feature_store_df.event_time <= query_df.query_time)
    & (datediff(query_df.query_time, feature_store_df.event_time) <= allowed_staleness_days)
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [30]:
# Keep the newest surviving record for each (id, query_time) pair.
# The key must include query_time: keying on id alone would merge different query times
# for the same id into one result, so a query set that repeats an id would silently lose
# all but one of its point-in-time answers.
def _newer_record(a, b):
    """Return whichever of two joined rows is more recent.

    Rows are ordered by event_time, with api_invocation_time as the tie-breaker. On a
    complete tie, the first argument is returned.
    """
    # NOTE(review): if event_time ties and either api_invocation_time is None, the tuple
    # comparison raises TypeError — confirm that column is never null.
    if (a.event_time, a.api_invocation_time) >= (b.event_time, b.api_invocation_time):
        return a
    return b


take_latest = (
    drop_future_and_stale.rdd
    # Key-value pairs let reduceByKey combine rows per key without a full group-by.
    .map(lambda row: ((row.id, row.query_time), row))
    .reduceByKey(_newer_record)
    .values()  # drop the keys, leaving an RDD of the winning Rows
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [31]:
# Turn the reduced RDD of Rows back into a DataFrame, reusing the pre-reduction schema so
# column names and types are carried over unchanged.
latest_df = spark.createDataFrame(take_latest, schema=drop_future_and_stale.schema)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [32]:
# Drop the store metadata and query-side columns listed below; what remains is id,
# event_time and the feature values.
# NOTE(review): dropping query_time means that if query_df ever holds several query times
# for one id, output rows can no longer be matched back to the query they answer.
columns_to_drop = [
    "write_time",
    "is_deleted",
    "year",
    "month",
    "day",
    "hour",
    "query_time",
    "api_invocation_time",
]
selected = latest_df.drop(*columns_to_drop)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [33]:
# Show query result (show() prints at most 20 rows, with long values truncated, by default).
selected.show()

# To save query result to s3 (mode="overwrite" replaces anything already at OUTPUT_PATH):
# OUTPUT_PATH = f"s3://{BUCKET}/{PREFIX}/test_query_output"
# selected.write.parquet(OUTPUT_PATH, mode="overwrite")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---+--------------------+---------+
| id|          event_time|feature_1|
+---+--------------------+---------+
|  4|2021-03-03T20:51:...|      4.3|
|  1|2021-03-04T22:51:...|      1.4|
+---+--------------------+---------+