### Use SageMaker Feature Store and Apache Spark to generate point-in-time queries to implement Time Travel
The following notebook builds a set of Spark DataFrames and queries that illustrate a pattern for "Time Travel" with SageMaker Feature Store. We will demonstrate how to build point-in-time feature sets. Techniques include building Spark DataFrames, joining across tables, and applying query filters to reduce the dataset.
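A point-in-time ("as of") lookup returns, for each query timestamp, the most recent feature value whose event time is not after that timestamp. Below is a minimal pure-Python sketch of that rule for a single entity. It assumes the history is already sorted by event time and that timestamps use the `YYYY-MM-DD HH:MM:SS` format, so string order matches time order. The `as_of` function and the sample history are hypothetical.

```python
from bisect import bisect_right


def as_of(history, query_time):
    """Return the latest (event_time, value) pair with event_time <= query_time, or None.

    `history` must be sorted by event_time. Timestamps are 'YYYY-MM-DD HH:MM:SS'
    strings, which compare correctly as plain strings.
    """
    times = [event_time for event_time, _ in history]
    # bisect_right counts the entries whose event_time <= query_time
    idx = bisect_right(times, query_time)
    return history[idx - 1] if idx else None


# Hypothetical history of one card's num_trans_last_1d feature
history = [
    ("2021-01-31 08:00:00", 3),
    ("2021-01-31 12:30:00", 5),
    ("2021-01-31 22:52:22", 18),
]

print(as_of(history, "2021-01-31 13:00:00"))  # ('2021-01-31 12:30:00', 5)
print(as_of(history, "2021-01-31 07:00:00"))  # None: no feature existed yet
```

The rest of the notebook applies the same rule at scale with Spark filters and joins instead of a sorted list.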

In [None]:
!pip install Faker

In [1]:
# Faker
from faker import Faker

# Import pyspark and build Spark session

from pyspark.sql import SparkSession
from pyspark.sql.functions import datediff
from pyspark.sql.functions import lit
from pyspark.sql.functions import col
from pyspark.sql.functions import max as sql_max
from pyspark.sql.functions import min as sql_min
from pyspark.sql.types import StringType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType

from pyspark import SparkContext, SparkConf
import sagemaker_pyspark
import datetime

# Configure Spark to use the SageMaker Spark dependency jars
classpath = ":".join(sagemaker_pyspark.classpath_jars())

#print(f"Spark classpath_jars: {classpath}")

In [2]:
spark = (SparkSession
    .builder
    .config("spark.driver.extraClassPath", classpath)
    .getOrCreate())


In [3]:
sc = spark.sparkContext
print(sc.version)

2.3.4


In [4]:
SEED = 123456
faker = Faker()
faker.seed_locale('en_US', 0)
faker.seed_instance(SEED)

In [5]:
import sagemaker

BUCKET = sagemaker.Session().default_bucket()
print(BUCKET)

sagemaker-us-east-1-572539092864


In [6]:
import os

BASE_PREFIX = "sagemaker-featurestore-demo"
OFFLINE_STORE_BASE_URI = f's3://{BUCKET}/{BASE_PREFIX}'
print(OFFLINE_STORE_BASE_URI)

AGG_PREFIX = os.path.join(BASE_PREFIX, 'aggregated')
print(AGG_PREFIX)

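# s3a:// is the scheme of Hadoop's S3A connector, which Spark uses for reads here.
# Despite the variable name, this prefix holds CSV files, not Parquet.
AGG_FEATURES_PATH_S3 = f"s3://{BUCKET}/{AGG_PREFIX}/"
AGG_FEATURES_PATH_PARQUET = f"s3a://{BUCKET}/{AGG_PREFIX}/"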

s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo
sagemaker-featurestore-demo/aggregated


In [7]:
from sagemaker.s3 import S3Downloader

file_list = S3Downloader.list(AGG_FEATURES_PATH_S3)

print(f'Using S3 path: {AGG_FEATURES_PATH_S3}')
print("Found files: \n" + "\n".join(file_list))

Using S3 path: s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/
Found files: 
s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/_SUCCESS
s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/part-00000-8905a404-2acc-42b8-9f77-bb12fa21ba4e-c000.csv


In [8]:
transactions_df = spark.read.options(Header=True).csv(AGG_FEATURES_PATH_PARQUET)

In [9]:
transactions_df.printSchema()
transactions_df.count()

root
 |-- tid: string (nullable = true)
 |-- event_time: string (nullable = true)
 |-- cc_num: string (nullable = true)
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- num_trans_last_1d: string (nullable = true)
 |-- avg_amt_last_1d: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)



500001

In [10]:
transactions_df.show(5)

+--------------------+--------------------+----------------+-------+-----------+------------------+----------------+-----------------+------------------+--------------------+--------------------+------------------+
|                 tid|          event_time|          cc_num| amount|fraud_label|num_trans_last_60m|avg_amt_last_60m|num_trans_last_1d|   avg_amt_last_1d|          amt_ratio1|          amt_ratio2|       count_ratio|
+--------------------+--------------------+----------------+-------+-----------+------------------+----------------+-----------------+------------------+--------------------+--------------------+------------------+
|ec9e8c973eabb24ff...|2021-01-01T02:10:...|4028853934607849|6698.67|          0|                 1|         6698.67|                1|           6698.67|                 1.0|                 1.0|               1.0|
|6c708b2bea5bec096...|2021-01-01T03:11:...|4028853934607849| 338.97|          0|                 1|          338.97|                2|      

In [11]:
OFFLINE_STORE_URI = \
    "s3a://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/572539092864/sagemaker/us-east-1/offline-store/cc-agg-batch-fg-1618489640/data/"
#"s3a://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/572539092864/sagemaker/us-east-1/offline-store/cc-agg-batch-fg-1618437897/data/"
#OFFLINE_STORE_URI = "s3a://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/572539092864/sagemaker/us-east-1/offline-store/cc-agg-batch-fg-1618411941/data/year=2021/month=04/day=14/hour=14/"
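The hard-coded `OFFLINE_STORE_URI` above follows a fixed layout:

* bucket and prefix
* account ID
* `sagemaker`
* region
* `offline-store`
* the feature group's table folder
* `data/`, optionally followed by `year=/month=/day=/hour=` partition folders

The helper below assembles such a path from its parts, assuming that layout, so the literal string does not need to be edited by hand. It is hypothetical and is not part of the SageMaker SDK.

```python
import datetime


def offline_store_data_uri(bucket, prefix, account_id, region, table_folder,
                           partition_time=None, scheme="s3a"):
    """Hypothetical helper: build an offline store data URI from its parts.

    table_folder is the per-feature-group folder name, e.g. 'cc-agg-batch-fg-1618489640'.
    If partition_time is given, append zero-padded year=/month=/day=/hour= folders.
    """
    uri = (f"{scheme}://{bucket}/{prefix}/{account_id}/sagemaker/{region}"
           f"/offline-store/{table_folder}/data/")
    if partition_time is not None:
        # datetime supports strftime codes inside f-string format specs
        uri += (f"year={partition_time:%Y}/month={partition_time:%m}"
                f"/day={partition_time:%d}/hour={partition_time:%H}/")
    return uri


uri = offline_store_data_uri(
    "sagemaker-us-east-1-572539092864", "sagemaker-featurestore-demo",
    "572539092864", "us-east-1", "cc-agg-batch-fg-1618489640")
print(uri)
```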


In [12]:
##feature_store_df = spark.read.parquet(OFFLINE_STORE_BASE_URI + "/data")

feature_store_df = spark.read.parquet(OFFLINE_STORE_URI)
feature_store_df.printSchema()
feature_store_df.count()

root
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)



1000

In [13]:
feature_store_df.show(5)

+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
|                 tid|          cc_num|num_trans_last_1d|avg_amt_last_1d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|
+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
|6c7825ae8dde6f430...|4047269399322294|               20|         929.34|2021-01-31 23:32:29|1.618490011E9|2021-04-15 12:38:...|2021-04-15 12:33:30|     false|2021|    4| 15|  12|
|298e2cab41aed7dbf...|4387164162852064|               16|         287.72|2021-01-31 23:27:24|1.618490011E9|2021-04-15 12:38:...|2021-04-15 12:33:30|     false|2021|    4| 15|  12|
|065076eddd6bb8079...|4447228755741220|               11|         307.44|2021-01-31 22:03:52|1.61849

In [14]:
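# Join the transactions table to the aggregate feature table from the offline store.
# The left outer join keeps every transaction; transactions with no feature store match get nulls
# in the feature store columns. The duplicated transactions_df columns are dropped so the
# feature store versions are kept.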

enhanced_df = (transactions_df.join(feature_store_df, transactions_df.tid == feature_store_df.tid, "left_outer")
    .drop(transactions_df.tid)
    .drop(transactions_df.cc_num)
    .drop(transactions_df.num_trans_last_1d)
    .drop(transactions_df.avg_amt_last_1d)
    .drop(transactions_df.event_time))

enhanced_df.printSchema()
enhanced_df.count()

root
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)



500001

In [None]:
#enhanced_df.createOrReplaceTempView("features")

In [15]:
import calendar

# create cutoff time
cutoff_time = datetime.datetime(2021, 2, 1)
cutoff_cal = calendar.timegm(cutoff_time.timetuple())

print(f'Cutoff time: {cutoff_time}')
print(f'Cutoff cal:  {cutoff_cal}')

Cutoff time: 2021-02-01 00:00:00
Cutoff cal:  1612137600
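`calendar.timegm` interprets the naive `cutoff_time` as UTC, so `cutoff_cal` is the same on any host. By contrast, `time.mktime` would use the machine's local time zone. The sketch below round-trips that value.

It also converts a `trans_time` value back to a UTC datetime. This assumes the feature store's `trans_time` column holds epoch seconds: the `1.618490011E9` shown earlier lines up with the `api_invocation_time` values, but that is an inference, not something the notebook states.

```python
import calendar
import datetime

UTC = datetime.timezone.utc

cutoff_time = datetime.datetime(2021, 2, 1)              # naive; timegm treats it as UTC
cutoff_cal = calendar.timegm(cutoff_time.timetuple())    # seconds since the Unix epoch

# Round trip: epoch seconds -> timezone-aware UTC datetime
cutoff_back = datetime.datetime.fromtimestamp(cutoff_cal, tz=UTC)

# Assumption: trans_time is epoch seconds stored as a double
trans_time = 1.618490011e9
trans_dt = datetime.datetime.fromtimestamp(trans_time, tz=UTC)

print(cutoff_cal)                # 1612137600
print(cutoff_back.isoformat())   # 2021-02-01T00:00:00+00:00
print(trans_dt.isoformat())      # 2021-04-15T12:33:31+00:00
```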


In [None]:
## events = spark.sql("SELECT transactionid, transactionamt, transactiondt, write_time FROM features WHERE write_time < '" + str(cutoff_time) + "' ORDER BY transactiondt")
#events_window = spark.sql("SELECT tid, event_time, trans_time, api_invocation_time, cc_num, fraud_label, amount, num_trans_last_1d, avg_amt_last_1d \
#                    FROM features WHERE event_time <= '" + str(cutoff_time) + "'")

# removed "" ORDER BY event_time"
#print ("Count: " + str(events_window.count()))

In [None]:
#events_window.show(5)

### Sample Time Travel query from Studio

##### Simplified query for datasets up to 100 GB

SELECT *
FROM
    (SELECT *,
         row_number()
        OVER (PARTITION BY <RecordIdentifierFeatureName>
    ORDER BY  EventTime DESC, Api_Invocation_Time DESC, write_time DESC) AS row_number
    FROM "sagemaker_featurestore"."identity-feature-group-03-20-32-44-1614803787"
    WHERE EventTime <= timestamp '<timestamp>')
    -- replace <RecordIdentifierFeatureName> with the feature group's record identifier column
    -- replace timestamp '<timestamp>' with just <timestamp>  if EventTimeFeature is of type fractional
    -- hyphenated table names must be double-quoted in Athena
WHERE row_number = 1 AND
NOT is_deleted
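Read row by row, the Studio query does the following:

1. For each record identifier, it keeps the single row with the latest `EventTime` at or before the cutoff, breaking ties by `Api_Invocation_Time` and then `write_time`.
2. Only then does it drop that row if it is a delete marker.

Because of that order, a deleted record does not fall back to an older version. Below is a pure-Python sketch of those semantics. It assumes rows are dicts using this notebook's offline store column names and same-format timestamp strings. The function names and the rows are hypothetical.

```python
def time_travel(rows, as_of, id_col="tid"):
    """Latest non-deleted version of each record as of `as_of` (inclusive)."""
    latest = {}
    for r in rows:
        if r["event_time"] > as_of:  # WHERE event_time <= as_of
            continue
        # ORDER BY event_time DESC, api_invocation_time DESC, write_time DESC
        rank = (r["event_time"], r["api_invocation_time"], r["write_time"])
        key = r[id_col]  # PARTITION BY the record identifier
        if key not in latest or rank > latest[key][0]:
            latest[key] = (rank, r)
    # row_number = 1 is applied first, then NOT is_deleted
    return [r for _, r in latest.values() if not r["is_deleted"]]


def make_row(tid, event_time, invoked, written, deleted, value):
    return {"tid": tid, "event_time": event_time, "api_invocation_time": invoked,
            "write_time": written, "is_deleted": deleted, "value": value}


rows = [
    make_row("a", "2021-01-31 10:00:00", "2021-04-15 12:33:30", "2021-04-15 12:38:00", False, 1),
    make_row("a", "2021-01-31 10:00:00", "2021-04-15 13:00:00", "2021-04-15 13:05:00", False, 2),  # re-ingested
    make_row("b", "2021-01-31 09:00:00", "2021-04-15 12:33:30", "2021-04-15 12:38:00", False, 3),
    make_row("b", "2021-01-31 11:00:00", "2021-04-15 12:33:30", "2021-04-15 12:38:00", True, 4),   # delete marker
    make_row("c", "2021-02-02 00:00:00", "2021-04-15 12:33:30", "2021-04-15 12:38:00", False, 5),  # after cutoff
]

print([r["value"] for r in time_travel(rows, "2021-02-01 00:00:00")])  # [2]
print([r["value"] for r in time_travel(rows, "2021-01-31 10:30:00")])  # [2, 3]
```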

### Build the query DataFrame that spans the intended time window

In [16]:
import random

NUM_RANDOM_SAMPLES = 100

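# Collect every transaction's cc_num to the driver (cards with more transactions appear more often)
cc_num_list = transactions_df.rdd.map(lambda x: x.cc_num).collect()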

In [17]:
cc_num_sample = random.sample(cc_num_list, NUM_RANDOM_SAMPLES)
print(cc_num_sample)

['4730351302648825', '4557873358799582', '4436369522230759', '4629251671616490', '4827760497563827', '4859223670509563', '4776202749026163', '4456277494923688', '4134120990603661', '4953391840037198', '4250799428404202', '4853206196105715', '4185260096189427', '4275108997597522', '4135894046934556', '4997434680860069', '4704235247099792', '4580063469039042', '4270322534709178', '4364372509439829', '4193799047482548', '4936052901371668', '4543292577127706', '4883290698301356', '4946069718963279', '4462604072071071', '4460285888258185', '4677777148375543', '4001629714607162', '4259787421397516', '4784109640131705', '4242134327115363', '4417214091284428', '4646554697578571', '4084987407985604', '4615861341152316', '4327251504706422', '4666545769123077', '4931602919413376', '4637117798226556', '4784755052999284', '4094026775370087', '4125310868424651', '4411504524619489', '4864410677210781', '4663001507344834', '4764837772829217', '4153347371820756', '4104686603950695', '4666301057221474',

In [18]:
# Pair each sampled cc_num with a random (Faker) query timestamp on 2021-01-31
start = datetime.datetime.strptime('2021-01-31 00:00:00', '%Y-%m-%d %H:%M:%S')
end = datetime.datetime.strptime('2021-01-31 23:59:59', '%Y-%m-%d %H:%M:%S')

samples = list()
for r in range(NUM_RANDOM_SAMPLES):
    row = []
    fake_timestamp = faker.date_time_between(start_date=start, end_date=end, tzinfo=None).strftime('%Y-%m-%d %H:%M:%S')
    row.append(cc_num_sample[r])
    row.append(fake_timestamp)
    samples.append(row)
    
print(samples)

[['4730351302648825', '2021-01-31 10:32:47'], ['4557873358799582', '2021-01-31 01:04:16'], ['4436369522230759', '2021-01-31 06:21:32'], ['4629251671616490', '2021-01-31 00:04:52'], ['4827760497563827', '2021-01-31 02:48:19'], ['4859223670509563', '2021-01-31 01:51:26'], ['4776202749026163', '2021-01-31 09:46:35'], ['4456277494923688', '2021-01-31 01:02:03'], ['4134120990603661', '2021-01-31 04:14:16'], ['4953391840037198', '2021-01-31 08:20:18'], ['4250799428404202', '2021-01-31 23:15:50'], ['4853206196105715', '2021-01-31 17:44:46'], ['4185260096189427', '2021-01-31 06:46:33'], ['4275108997597522', '2021-01-31 00:53:01'], ['4135894046934556', '2021-01-31 19:16:06'], ['4997434680860069', '2021-01-31 04:41:42'], ['4704235247099792', '2021-01-31 13:44:19'], ['4580063469039042', '2021-01-31 23:32:52'], ['4270322534709178', '2021-01-31 13:28:10'], ['4364372509439829', '2021-01-31 01:33:18'], ['4193799047482548', '2021-01-31 06:34:08'], ['4936052901371668', '2021-01-31 02:12:53'], ['4543292
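Faker is used here only to draw uniform random timestamps. Also, the `random.sample` call above is not seeded, so the entity rows change between runs. If reproducibility matters, one stdlib-only alternative is to drive both the card sampling and the timestamps from one seeded `random.Random`. This is a swap, not what this notebook does.

The sketch also samples from distinct card numbers. Sampling `cc_num_list` directly, as above, favors cards with more transactions. `build_query_rows` is a hypothetical name.

```python
import datetime
import random


def build_query_rows(cc_num_list, k, start, end, seed=123456):
    """Return k [cc_num, 'YYYY-MM-DD HH:MM:SS'] rows, reproducible for a given seed."""
    rng = random.Random(seed)
    # sorted(set(...)) fixes the candidate order, so the same seed picks the same cards
    candidates = sorted(set(cc_num_list))
    chosen = rng.sample(candidates, k)
    span_seconds = int((end - start).total_seconds())
    rows = []
    for cc_num in chosen:
        # randint is inclusive, so timestamps fall in [start, end]
        ts = start + datetime.timedelta(seconds=rng.randint(0, span_seconds))
        rows.append([cc_num, ts.strftime("%Y-%m-%d %H:%M:%S")])
    return rows


start = datetime.datetime(2021, 1, 31, 0, 0, 0)
end = datetime.datetime(2021, 1, 31, 23, 59, 59)
cards = ["4730351302648825", "4557873358799582", "4436369522230759", "4730351302648825"]
print(build_query_rows(cards, 2, start, end))
```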

In [19]:
# Schema for the query (entity) DataFrame: one row per (cc_num, query timestamp)
query_df_schema = StructType([
    # entity key; the transaction id (tid) could be used instead
    StructField('cc_num', StringType(), False),
    StructField('joindate', StringType(), False)
])

In [20]:
# Create entity data frame

query_df = spark.createDataFrame(samples, query_df_schema)
query_df.show()

+----------------+-------------------+
|          cc_num|           joindate|
+----------------+-------------------+
|4730351302648825|2021-01-31 10:32:47|
|4557873358799582|2021-01-31 01:04:16|
|4436369522230759|2021-01-31 06:21:32|
|4629251671616490|2021-01-31 00:04:52|
|4827760497563827|2021-01-31 02:48:19|
|4859223670509563|2021-01-31 01:51:26|
|4776202749026163|2021-01-31 09:46:35|
|4456277494923688|2021-01-31 01:02:03|
|4134120990603661|2021-01-31 04:14:16|
|4953391840037198|2021-01-31 08:20:18|
|4250799428404202|2021-01-31 23:15:50|
|4853206196105715|2021-01-31 17:44:46|
|4185260096189427|2021-01-31 06:46:33|
|4275108997597522|2021-01-31 00:53:01|
|4135894046934556|2021-01-31 19:16:06|
|4997434680860069|2021-01-31 04:41:42|
|4704235247099792|2021-01-31 13:44:19|
|4580063469039042|2021-01-31 23:32:52|
|4270322534709178|2021-01-31 13:28:10|
|4364372509439829|2021-01-31 01:33:18|
+----------------+-------------------+
only showing top 20 rows



In [None]:
# entity data frame

# add multiple instances for same cc_num
#query_df = spark.createDataFrame([
#      # instead of tid we use cc_num
#      [cc_num_sample[0], "2021-01-31T18:00:00Z"], 
#      [cc_num_sample[1], "2021-01-31T14:00:00Z"],
#      [cc_num_sample[2], "2021-01-31T12:00:00Z"],
#      [cc_num_sample[3], "2021-01-31T17:00:00Z"],
#      [cc_num_sample[4], "2021-01-31T13:00:00Z"]
#    ],
#    query_df_schema)

#query_df.show()

In [21]:
# Compute min and max times over our query data for filtering, in one pass for performance

# query_df used to define bounded time window
minmax_time = query_df.agg(sql_min("joindate"), sql_max("joindate")).collect()
print(minmax_time)

[Row(min(joindate)='2021-01-31 00:04:52', max(joindate)='2021-01-31 23:32:52')]


In [22]:
min_time, max_time = minmax_time[0]["min(joindate)"], minmax_time[0]["max(joindate)"]
print(f'min_time: {min_time}')
print(f'max_time: {max_time}')

min_time: 2021-01-31 00:04:52
max_time: 2021-01-31 23:32:52


In [23]:
#print("Before filter, count: " + str(events_window.count()))
print("Before filter, count: " + str(enhanced_df.count()))

Before filter, count: 500001


In [None]:
# Filter deleted records out
#events_window = events_window.filter(~events_window.is_deleted)
#print("After count: " + str(events_window.count()))

In [24]:
%%time

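# Coarse pre-filter: drop records after the latest query time and records older than the
# staleness window before the earliest query time. Doing this before the per-(cc_num, query_time)
# filtering is a performance optimization, since it shrinks the data before the join.

# datediff(end, start) returns the whole number of days from start to end (time of day is ignored)
allowed_staleness_days = 4

filtered = enhanced_df.filter(
    # at most allowed_staleness_days days before min_time
    (datediff(lit(min_time), enhanced_df.event_time) <= allowed_staleness_days)
    # and no later than the latest query time ('YYYY-MM-DD HH:MM:SS' strings compare correctly)
    & (enhanced_df.event_time <= max_time)
    # rows with a null event_time (no feature store match) fail both tests and are dropped
)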

filtered.printSchema()
print("After filter, count: " + str(filtered.count()))

root
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)

After filter, count: 728
CPU times: user 0 ns, sys: 3.15 ms, total: 3.15 ms
Wall time: 894 ms
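Spark's `datediff(end, start)` returns the number of days from `start` to `end`, comparing calendar dates only. The argument order therefore decides the direction of the staleness check.

A pure-Python mimic makes the coarse window easy to check. It is an approximation for `YYYY-MM-DD HH:MM:SS` strings, not Spark's implementation. With `min_time = '2021-01-31 00:04:52'` and 4 days of staleness, anything dated 2021-01-27 or later passes. Reversing the arguments would instead let arbitrarily old rows through.

```python
import datetime


def datediff(end, start):
    """Mimic Spark SQL datediff(end, start) for 'YYYY-MM-DD HH:MM:SS' strings: days from start to end."""
    def to_date(ts):
        return datetime.datetime.strptime(ts[:10], "%Y-%m-%d").date()
    return (to_date(end) - to_date(start)).days


def in_window(event_time, min_time, max_time, staleness_days=4):
    """Coarse pre-filter: no later than max_time, at most staleness_days days before min_time."""
    return datediff(min_time, event_time) <= staleness_days and event_time <= max_time


min_time, max_time = "2021-01-31 00:04:52", "2021-01-31 23:32:52"
for ts in ["2021-01-27 00:00:01", "2021-01-26 23:59:59", "2021-01-31 23:40:00", "2021-01-01 02:10:00"]:
    print(ts, in_window(ts, min_time, max_time))
# 2021-01-27 00:00:01 True
# 2021-01-26 23:59:59 False
# 2021-01-31 23:40:00 False
# 2021-01-01 02:10:00 False

# With the arguments reversed, the 30-day-old row would pass the staleness test:
print(datediff("2021-01-01 02:10:00", min_time) <= 4)  # True (datediff is -30)
```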


In [25]:
filtered.select("cc_num").show(10)

+----------------+
|          cc_num|
+----------------+
|4028853934607849|
|4047269399322294|
|4064963388466975|
|4171133180677795|
|4364372509439829|
|4376838888917748|
|4387164162852064|
|4437529857770948|
|4447228755741220|
|4453776103531090|
+----------------+
only showing top 10 rows



In [26]:
# Inner join with the query rows on cc_num, then drop the duplicate cc_num column from query_df

joined = filtered.join(query_df, filtered.cc_num == query_df.cc_num, "inner").drop(query_df.cc_num)
print("Joined count: " + str(joined.count()))

Joined count: 70


In [27]:
joined.show()

+-------+-----------+------------------+------------------+--------------------+--------------------+--------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+-------------------+
| amount|fraud_label|num_trans_last_60m|  avg_amt_last_60m|          amt_ratio1|          amt_ratio2|         count_ratio|                 tid|          cc_num|num_trans_last_1d|avg_amt_last_1d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|           joindate|
+-------+-----------+------------------+------------------+--------------------+--------------------+--------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+-------------------+
|  10.85|          0|                

In [28]:
# Keep records at or before each row's query time (joindate) and
# no more than allowed_staleness_days days before it
drop_future_and_stale = joined.filter(
    (joined.event_time <= joined.joindate)
    & (datediff(joined.joindate, joined.event_time) <= allowed_staleness_days)
)
print("After drop stale, count: " + str(drop_future_and_stale.count()))

After drop stale, count: 6


In [29]:
drop_future_and_stale.show()

+------+-----------+------------------+------------------+-------------------+-------------------+-------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+-------------------+
|amount|fraud_label|num_trans_last_60m|  avg_amt_last_60m|         amt_ratio1|         amt_ratio2|        count_ratio|                 tid|          cc_num|num_trans_last_1d|avg_amt_last_1d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|           joindate|
+------+-----------+------------------+------------------+-------------------+-------------------+-------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+-------------------+
| 86.37|          0|                 1|          

In [30]:
# Group by record identifier (tid) and keep the latest version of each record
take_latest = (
    drop_future_and_stale.rdd.map(lambda x: (x.tid, x))  # to RDD with KVPs so we can use efficient reduceByKey
    .reduceByKey(
        lambda x, y: x if ((x.event_time, x.api_invocation_time) >= (y.event_time, y.api_invocation_time)) else y
    )  # Use API invocation time as tie-breaker
    .values()  # drop keys
)
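The `reduceByKey` step keeps, per key, the row with the largest `(event_time, api_invocation_time)` tuple. What it deduplicates depends on the key:

* Keyed by `tid`, it collapses repeated versions of the same record, which matches the per-record semantics of the Studio query.
* Keyed by `(cc_num, joindate)`, it would keep exactly one feature row per query row, which is an as-of join.

Below is a pure-Python sketch of both. The rows and the `take_latest` helper are hypothetical.

```python
def newer(x, y):
    """Same rule as the reduceByKey lambda: larger (event_time, api_invocation_time) wins."""
    if (x["event_time"], x["api_invocation_time"]) >= (y["event_time"], y["api_invocation_time"]):
        return x
    return y


def take_latest(rows, key_fn):
    """One row per key, like rdd.map(lambda r: (key_fn(r), r)).reduceByKey(newer).values()."""
    best = {}
    for r in rows:
        k = key_fn(r)
        best[k] = newer(best[k], r) if k in best else r
    return list(best.values())


# Hypothetical: two feature rows for the same card that both satisfy the same query row
rows = [
    {"tid": "t1", "cc_num": "4000000000000001", "event_time": "2021-01-31 20:10:00",
     "api_invocation_time": "2021-04-15 12:33:30", "joindate": "2021-01-31 23:30:00"},
    {"tid": "t2", "cc_num": "4000000000000001", "event_time": "2021-01-31 22:50:00",
     "api_invocation_time": "2021-04-15 12:33:30", "joindate": "2021-01-31 23:30:00"},
]

by_record = take_latest(rows, lambda r: r["tid"])
by_query_row = take_latest(rows, lambda r: (r["cc_num"], r["joindate"]))
print(len(by_record))                     # 2: each tid is its own group
print([r["tid"] for r in by_query_row])   # ['t2']: latest row as of the query time
```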

In [31]:
# Convert to DF
latest_df = take_latest.toDF(drop_future_and_stale.schema)

In [32]:
# Drop Feature Store metadata and partition columns
columns_to_drop = ["write_time", "is_deleted", "year", "month", "day", "hour", "api_invocation_time"]
selected = latest_df.drop(*columns_to_drop)

In [33]:
# Show query result
selected.show()

# To save the query result to S3:
# OUTPUT_PATH = f"s3a://{BUCKET}/{BASE_PREFIX}/test_query_output"
# selected.write.parquet(OUTPUT_PATH, mode="overwrite")

+------+-----------+------------------+------------------+-------------------+-------------------+-------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+-------------------+
|amount|fraud_label|num_trans_last_60m|  avg_amt_last_60m|         amt_ratio1|         amt_ratio2|        count_ratio|                 tid|          cc_num|num_trans_last_1d|avg_amt_last_1d|         event_time|   trans_time|           joindate|
+------+-----------+------------------+------------------+-------------------+-------------------+-------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+-------------------+
| 45.07|          0|                 2|59.760000000000005|0.16847433682650914|0.12706054820566878| 0.1111111111111111|63cbb095f913e596c...|4580063469039042|               18|         354.71|2021-01-31 22:52:22|1.618490015E9|2021-01-31 23:32:52|
| 34.17|          0|