### Use SageMaker Feature Store and Apache Spark to generate point-in-time queries to implement Time Travel
This notebook builds a set of Spark DataFrames and queries that provide a pattern for "Time Travel" with SageMaker Feature Store. We demonstrate how to build point-in-time feature sets. Techniques include building Spark DataFrames, joining tables, and applying query filters to reduce the dataset.
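
To make the goal concrete before involving Spark, here is a minimal plain-Python sketch of a point-in-time lookup: for an entity and a query time, return the latest feature record whose event time is not after the query time. The records, the card numbers, and the `point_in_time_lookup` helper are hypothetical and exist only for illustration; they are not part of the notebook's dataset or of the Feature Store API.

```python
from datetime import datetime

FMT = "%Y-%m-%d %H:%M:%S"

# Hypothetical feature history: (cc_num, event_time, avg_amt_last_1d)
history = [
    ("1111", "2021-01-30 09:00:00", 10.0),
    ("1111", "2021-01-31 09:00:00", 20.0),
    ("1111", "2021-01-31 18:00:00", 30.0),
    ("2222", "2021-01-31 12:00:00", 99.0),
]


def point_in_time_lookup(history, cc_num, query_time):
    """Return the latest record for cc_num with event_time <= query_time, or None."""
    cutoff = datetime.strptime(query_time, FMT)
    candidates = [
        rec for rec in history
        if rec[0] == cc_num and datetime.strptime(rec[1], FMT) <= cutoff
    ]
    # max() with default=None handles entities that have no eligible record
    return max(candidates, key=lambda rec: datetime.strptime(rec[1], FMT), default=None)


result_mid = point_in_time_lookup(history, "1111", "2021-01-31 12:00:00")
result_early = point_in_time_lookup(history, "2222", "2021-01-31 11:00:00")
print(result_mid)
print(result_early)
```

The record written at 18:00 is ignored for the 12:00 query because it would leak information from the future, and the second lookup returns `None` because that card has no record yet at the query time.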

In [None]:
!pip install Faker

In [1]:
# Faker
from faker import Faker

# Import pyspark and build Spark session

from pyspark.sql import SparkSession
from pyspark.sql.functions import datediff
from pyspark.sql.functions import lit
from pyspark.sql.functions import col
from pyspark.sql.functions import max as sql_max
from pyspark.sql.functions import min as sql_min
from pyspark.sql.types import StringType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType

from pyspark import SparkContext, SparkConf
import sagemaker_pyspark
import datetime

# Configure Spark to use the SageMaker Spark dependency jars
classpath = ":".join(sagemaker_pyspark.classpath_jars())

#print(f"Spark classpath_jars: {classpath}")

In [2]:
spark = (SparkSession
    .builder
    .config("spark.driver.extraClassPath", classpath)
    .getOrCreate())


In [3]:
sc = spark.sparkContext
print(sc.version)

2.3.4


In [4]:
SEED = 123456
faker = Faker()
faker.seed_locale('en_US', 0)
faker.seed_instance(SEED)

In [5]:
import sagemaker

BUCKET = sagemaker.Session().default_bucket()
print(BUCKET)

sagemaker-us-east-1-572539092864


In [6]:
import os

BASE_PREFIX = "sagemaker-featurestore-demo"
OFFLINE_STORE_BASE_URI = f's3://{BUCKET}/{BASE_PREFIX}'
print(OFFLINE_STORE_BASE_URI)

AGG_PREFIX = os.path.join(BASE_PREFIX, 'aggregated')
print(AGG_PREFIX)

AGG_FEATURES_PATH_S3 = f"s3://{BUCKET}/{AGG_PREFIX}/"
AGG_FEATURES_PATH_PARQUET = f"s3a://{BUCKET}/{AGG_PREFIX}/"

s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo
sagemaker-featurestore-demo/aggregated


In [7]:
from sagemaker.s3 import S3Downloader

file_list = S3Downloader.list(AGG_FEATURES_PATH_S3)

print(f'Using S3 path: {AGG_FEATURES_PATH_S3}')
print("Found files: \n" + "\n".join(file_list))

Using S3 path: s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/
Found files: 
s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/_SUCCESS
s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/part-00000-8905a404-2acc-42b8-9f77-bb12fa21ba4e-c000.csv


In [8]:
transactions_df = spark.read.options(Header=True).csv(AGG_FEATURES_PATH_PARQUET)

In [9]:
transactions_df.printSchema()
transactions_df.count()

root
 |-- tid: string (nullable = true)
 |-- event_time: string (nullable = true)
 |-- cc_num: string (nullable = true)
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- num_trans_last_1d: string (nullable = true)
 |-- avg_amt_last_1d: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)



500001

In [10]:
transactions_df.show(5)

+--------------------+--------------------+----------------+-------+-----------+------------------+----------------+-----------------+------------------+--------------------+--------------------+------------------+
|                 tid|          event_time|          cc_num| amount|fraud_label|num_trans_last_60m|avg_amt_last_60m|num_trans_last_1d|   avg_amt_last_1d|          amt_ratio1|          amt_ratio2|       count_ratio|
+--------------------+--------------------+----------------+-------+-----------+------------------+----------------+-----------------+------------------+--------------------+--------------------+------------------+
|ec9e8c973eabb24ff...|2021-01-01T02:10:...|4028853934607849|6698.67|          0|                 1|         6698.67|                1|           6698.67|                 1.0|                 1.0|               1.0|
|6c708b2bea5bec096...|2021-01-01T03:11:...|4028853934607849| 338.97|          0|                 1|          338.97|                2|      

In [11]:
# Load S3 location for Offline Store 
OFFLINE_STORE_URI = "s3a://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/572539092864/sagemaker/us-east-1/offline-store/cc-agg-batch-fg-1618489640/data/"


In [12]:
# Read Offline Store data
feature_store_df = spark.read.parquet(OFFLINE_STORE_URI)
feature_store_df.printSchema()
feature_store_df.count()

root
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)



1000

In [13]:
feature_store_df.show(5)

+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
|                 tid|          cc_num|num_trans_last_1d|avg_amt_last_1d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|
+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
|6c7825ae8dde6f430...|4047269399322294|               20|         929.34|2021-01-31 23:32:29|1.618490011E9|2021-04-15 12:38:...|2021-04-15 12:33:30|     false|2021|    4| 15|  12|
|298e2cab41aed7dbf...|4387164162852064|               16|         287.72|2021-01-31 23:27:24|1.618490011E9|2021-04-15 12:38:...|2021-04-15 12:33:30|     false|2021|    4| 15|  12|
|065076eddd6bb8079...|4447228755741220|               11|         307.44|2021-01-31 22:03:52|1.61849

In [14]:
# Join the raw transactions table to the aggregate feature table

enhanced_df = (transactions_df.join(feature_store_df, transactions_df.tid == feature_store_df.tid, "left_outer")
    .drop(transactions_df.tid)
    .drop(transactions_df.cc_num)
    .drop(transactions_df.num_trans_last_1d)
    .drop(transactions_df.avg_amt_last_1d)
    .drop(transactions_df.event_time))

enhanced_df.printSchema()
enhanced_df.count()

root
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)



500001

In [15]:
# Early cutoff to prune extra data
import calendar

# create cutoff time
cutoff_time = datetime.datetime(2021, 2, 1)
cutoff_cal = calendar.timegm(cutoff_time.timetuple())

print(f'Cutoff time: {cutoff_time}')
print(f'Cutoff cal:  {cutoff_cal}')

Cutoff time: 2021-02-01 00:00:00
Cutoff cal:  1612137600
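
`calendar.timegm` interprets the time tuple as UTC, which is why the naive `datetime(2021, 2, 1)` above maps to 1612137600. The sketch below uses only the standard library to check that an explicitly UTC-aware datetime yields the same epoch value; the variable names are illustrative and not part of the notebook.

```python
import calendar
from datetime import datetime, timezone

# Naive datetime, converted the same way as in the cell above
naive_cutoff = datetime(2021, 2, 1)
epoch_from_timegm = calendar.timegm(naive_cutoff.timetuple())

# Same instant, but with the UTC time zone stated explicitly
aware_cutoff = datetime(2021, 2, 1, tzinfo=timezone.utc)
epoch_from_timestamp = int(aware_cutoff.timestamp())

print(epoch_from_timegm, epoch_from_timestamp)
```

Calling `.timestamp()` on the naive datetime instead would use the machine's local time zone, so the result could differ from the value printed above on a host that is not set to UTC.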


### Sample Time Travel query from Studio

##### Simplified query for datasets up to 100 GB

SELECT *
FROM
    (SELECT *,
         row_number()
        OVER (PARTITION BY EventTime
    ORDER BY  EventTime desc, Api_Invocation_Time DESC, write_time DESC) AS row_number
    FROM sagemaker_featurestore.identity-feature-group-03-20-32-44-1614803787
    where EventTime <= timestamp '<timestamp>')
    -- replace timestamp '<timestamp>' with just <timestamp>  if EventTimeFeature is of type fractional
WHERE row_number = 1 and
NOT is_deleted
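
The window-function logic in this query can be hard to read, so here is a small plain-Python sketch of the same idea: keep rows at or before a timestamp, rank them per group by event time, API invocation time, and write time (all descending), take the top row, and discard it if it is marked deleted. The sketch assumes the intent is one row per record identifier, so it groups by a hypothetical `record_id` key rather than by `EventTime` as the query text is written; the rows and the `latest_as_of` helper are made up for illustration.

```python
from itertools import groupby
from operator import itemgetter

# Hypothetical offline-store rows
rows = [
    {"record_id": "a", "event_time": "2021-01-30 10:00:00",
     "api_invocation_time": "2021-04-15 12:00:00", "write_time": "2021-04-15 12:05:00",
     "is_deleted": False, "value": 1},
    {"record_id": "a", "event_time": "2021-01-31 10:00:00",
     "api_invocation_time": "2021-04-15 12:00:00", "write_time": "2021-04-15 12:05:00",
     "is_deleted": False, "value": 2},
    # Same event time written again later: the later API invocation should win
    {"record_id": "a", "event_time": "2021-01-31 10:00:00",
     "api_invocation_time": "2021-04-15 13:00:00", "write_time": "2021-04-15 13:05:00",
     "is_deleted": False, "value": 3},
    # After the requested timestamp: must not be returned
    {"record_id": "a", "event_time": "2021-02-02 10:00:00",
     "api_invocation_time": "2021-04-15 12:00:00", "write_time": "2021-04-15 12:05:00",
     "is_deleted": False, "value": 4},
    # Latest row for "b" is a delete marker: "b" must not appear in the result
    {"record_id": "b", "event_time": "2021-01-31 08:00:00",
     "api_invocation_time": "2021-04-15 12:00:00", "write_time": "2021-04-15 12:05:00",
     "is_deleted": True, "value": 5},
]


def latest_as_of(rows, timestamp):
    """Emulate: row_number() = 1 per record_id as of timestamp, then NOT is_deleted."""
    eligible = [r for r in rows if r["event_time"] <= timestamp]
    eligible.sort(key=itemgetter("record_id"))  # groupby needs sorted input
    rank_key = itemgetter("event_time", "api_invocation_time", "write_time")
    result = []
    for _, group in groupby(eligible, key=itemgetter("record_id")):
        top = max(group, key=rank_key)  # the row that would get row_number = 1
        if not top["is_deleted"]:
            result.append(top)
    return result


snapshot = latest_as_of(rows, "2021-02-01 00:00:00")
print([r["value"] for r in snapshot])
```

Timestamps are compared as strings here, which is safe only because they all share the same zero-padded `YYYY-MM-DD HH:MM:SS` layout.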

### Build Query Dataframe that spans the intended time window

In [None]:
# entity data frame

# add multiple instances for same cc_num
#query_df = spark.createDataFrame([
#      # instead of tid we use cc_num
#      [cc_num_sample[0], "2021-01-31T18:00:00Z"], 
#      [cc_num_sample[1], "2021-01-31T14:00:00Z"],
#      [cc_num_sample[2], "2021-01-31T12:00:00Z"],
#      [cc_num_sample[3], "2021-01-31T17:00:00Z"],
#      [cc_num_sample[4], "2021-01-31T13:00:00Z"]
#    ],
#    query_df_schema)

#query_df.show()

In [16]:
import random

# Num samples in entity dataframe
NUM_RANDOM_SAMPLES = 100

cc_num_list = transactions_df.rdd.map(lambda x: x.cc_num).collect()

In [17]:
cc_num_sample = random.sample(cc_num_list, NUM_RANDOM_SAMPLES)
print(cc_num_sample)

['4391394388523066', '4178703341147517', '4671096685272336', '4683617042712171', '4739672684659184', '4857743511793169', '4646120214659300', '4255558590192549', '4814121088764768', '4524584153018280', '4728806894593691', '4301684712758167', '4587576193681220', '4236603585689879', '4106727807825537', '4193799047482548', '4858014024523850', '4092336170988751', '4843834799716530', '4866429625756909', '4219457262138911', '4663671103936298', '4766960253027248', '4335550895610436', '4750764047485706', '4937909039675022', '4566212050718302', '4596231937265919', '4992021332110723', '4395886393337008', '4784109640131705', '4053728860399256', '4525813189752391', '4427066677260828', '4543145054222624', '4802445864928337', '4853952111585750', '4584401678600624', '4584998335559473', '4560154420940933', '4790550204645725', '4550792066895267', '4104825398215019', '4171133180677795', '4596231937265919', '4112847194141777', '4087089000407784', '4585047292159759', '4639389658139439', '4039621597854363',

In [18]:
# Build list of faked tuples of cc_num and timestamp
start = datetime.datetime.strptime('2021-01-31 00:00:00', '%Y-%m-%d %H:%M:%S')
end = datetime.datetime.strptime('2021-01-31 23:59:59', '%Y-%m-%d %H:%M:%S')

samples = list()
for r in range(NUM_RANDOM_SAMPLES):
    row = []
    fake_timestamp = faker.date_time_between(start_date=start, end_date=end, tzinfo=None).strftime('%Y-%m-%d %H:%M:%S')
    row.append(cc_num_sample[r])
    row.append(fake_timestamp)
    samples.append(row)
    
print(samples)

[['4391394388523066', '2021-01-31 10:32:47'], ['4178703341147517', '2021-01-31 01:04:16'], ['4671096685272336', '2021-01-31 06:21:32'], ['4683617042712171', '2021-01-31 00:04:52'], ['4739672684659184', '2021-01-31 02:48:19'], ['4857743511793169', '2021-01-31 01:51:26'], ['4646120214659300', '2021-01-31 09:46:35'], ['4255558590192549', '2021-01-31 01:02:03'], ['4814121088764768', '2021-01-31 04:14:16'], ['4524584153018280', '2021-01-31 08:20:18'], ['4728806894593691', '2021-01-31 23:15:50'], ['4301684712758167', '2021-01-31 17:44:46'], ['4587576193681220', '2021-01-31 06:46:33'], ['4236603585689879', '2021-01-31 00:53:01'], ['4106727807825537', '2021-01-31 19:16:06'], ['4193799047482548', '2021-01-31 04:41:42'], ['4858014024523850', '2021-01-31 13:44:19'], ['4092336170988751', '2021-01-31 23:32:52'], ['4843834799716530', '2021-01-31 13:28:10'], ['4866429625756909', '2021-01-31 01:33:18'], ['4219457262138911', '2021-01-31 06:34:08'], ['4663671103936298', '2021-01-31 02:12:53'], ['4766960

In [19]:
# Create and show the query DF (e.g. entity_dataframe)

query_df_schema = StructType([
    # change to transactionid (tid)
    StructField('cc_num', StringType(), False),
    StructField('joindate', StringType(), False)
])

In [20]:
# Create entity data frame

query_df = spark.createDataFrame(samples, query_df_schema)
query_df.show()

+----------------+-------------------+
|          cc_num|           joindate|
+----------------+-------------------+
|4391394388523066|2021-01-31 10:32:47|
|4178703341147517|2021-01-31 01:04:16|
|4671096685272336|2021-01-31 06:21:32|
|4683617042712171|2021-01-31 00:04:52|
|4739672684659184|2021-01-31 02:48:19|
|4857743511793169|2021-01-31 01:51:26|
|4646120214659300|2021-01-31 09:46:35|
|4255558590192549|2021-01-31 01:02:03|
|4814121088764768|2021-01-31 04:14:16|
|4524584153018280|2021-01-31 08:20:18|
|4728806894593691|2021-01-31 23:15:50|
|4301684712758167|2021-01-31 17:44:46|
|4587576193681220|2021-01-31 06:46:33|
|4236603585689879|2021-01-31 00:53:01|
|4106727807825537|2021-01-31 19:16:06|
|4193799047482548|2021-01-31 04:41:42|
|4858014024523850|2021-01-31 13:44:19|
|4092336170988751|2021-01-31 23:32:52|
|4843834799716530|2021-01-31 13:28:10|
|4866429625756909|2021-01-31 01:33:18|
+----------------+-------------------+
only showing top 20 rows



In [21]:
# Performance improvement:
# compute the min and max query times in a single aggregation pass, for use in filtering

# query_df used to define bounded time window
minmax_time = query_df.agg(sql_min("joindate"), sql_max("joindate")).collect()
print(minmax_time)

[Row(min(joindate)='2021-01-31 00:04:52', max(joindate)='2021-01-31 23:32:52')]


In [22]:
min_time, max_time = minmax_time[0]["min(joindate)"], minmax_time[0]["max(joindate)"]
print(f'min_time: {min_time}')
print(f'max_time: {max_time}')

min_time: 2021-01-31 00:04:52
max_time: 2021-01-31 23:32:52


In [23]:
print("Before filter, count: " + str(enhanced_df.count()))

Before filter, count: 500001


In [None]:
# Filter deleted records out
#events_window = events_window.filter(~events_window.is_deleted)
#print("After count: " + str(events_window.count()))

In [24]:
%%time

# Coarse pre-filter: drop records later than the query max_time or older than the
# staleness window before the query min_time.
# This is a performance optimization; pruning here makes the later per-{cc_num, joindate} filtering faster

# Number of days before min_time beyond which records are ignored
allowed_staleness_days = 4

# Remove cards (customers) that have no records inside the window [min_time - 4 days, max_time]
filtered = enhanced_df.filter(
    # datediff(enddate, startdate) returns the number of days from startdate to enddate
    # rows whose event_time is MORE than 4 days before min_time fall outside the buffer and are removed
    (datediff(lit(min_time), enhanced_df.event_time) <= allowed_staleness_days)
    & (enhanced_df.event_time <= max_time)
)

filtered.printSchema()
print("After filter, count: " + str(filtered.count()))

root
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)

After filter, count: 728
CPU times: user 2.8 ms, sys: 0 ns, total: 2.8 ms
Wall time: 925 ms
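
The coarse filter depends on how `datediff` counts days: it compares calendar dates and ignores the time of day. The following plain-Python sketch mirrors that predicate so the boundary cases are easy to inspect. The helpers `datediff_days` and `in_coarse_window` are hypothetical names introduced here, and the event times are invented; only `min_time`, `max_time`, and the 4-day staleness value come from the cells above.

```python
from datetime import datetime

FMT = "%Y-%m-%d %H:%M:%S"

min_time = "2021-01-31 00:04:52"
max_time = "2021-01-31 23:32:52"
allowed_staleness_days = 4


def datediff_days(end, start):
    """Calendar days from start to end, ignoring time of day."""
    end_date = datetime.strptime(end, FMT).date()
    start_date = datetime.strptime(start, FMT).date()
    return (end_date - start_date).days


def in_coarse_window(event_time):
    """True when the event survives the global pre-filter."""
    not_too_old = datediff_days(min_time, event_time) <= allowed_staleness_days
    not_in_future = event_time <= max_time  # same-format strings sort chronologically
    return not_too_old and not_in_future


checks = {
    "2021-01-27 23:59:59": in_coarse_window("2021-01-27 23:59:59"),  # exactly 4 days back
    "2021-01-26 23:59:59": in_coarse_window("2021-01-26 23:59:59"),  # 5 days back
    "2021-01-31 12:00:00": in_coarse_window("2021-01-31 12:00:00"),  # inside the window
    "2021-01-31 23:45:00": in_coarse_window("2021-01-31 23:45:00"),  # after max_time
}
print(checks)
```

Because the comparison is date-based, an event late on 2021-01-27 still counts as 4 days old even though it is closer to 3 days before `min_time` in elapsed time.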


In [25]:
filtered.select("cc_num").show(10)

+----------------+
|          cc_num|
+----------------+
|4028853934607849|
|4047269399322294|
|4064963388466975|
|4171133180677795|
|4364372509439829|
|4376838888917748|
|4387164162852064|
|4437529857770948|
|4447228755741220|
|4453776103531090|
+----------------+
only showing top 10 rows



In [26]:
# Join with query set; drop duplicate id field
# .drop(feature_store_df.tid)

joined = filtered.join(query_df, filtered.cc_num == query_df.cc_num, "inner").drop(query_df.cc_num)
print("Joined count: " + str(joined.count()))

Joined count: 77


In [28]:
joined.show(5)

+-------+-----------+------------------+------------------+------------------+-------------------+-------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+-------------------+
| amount|fraud_label|num_trans_last_60m|  avg_amt_last_60m|        amt_ratio1|         amt_ratio2|        count_ratio|                 tid|          cc_num|num_trans_last_1d|avg_amt_last_1d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|           joindate|
+-------+-----------+------------------+------------------+------------------+-------------------+-------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+-------------------+
|  95.08|          0|                 2|         

In [29]:
# Filter out data from after query time or before query time minus staleness window
# this query removes events outside the time window FOR the SPECIFIC CC (customer)
drop_future_and_stale = joined.filter(
    (joined.event_time <= query_df.joindate)
    & (datediff(query_df.joindate, joined.event_time) <= allowed_staleness_days)
)
print("After drop stale, count: " + str(drop_future_and_stale.count()))

After drop stale, count: 6
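
The inner join followed by the per-row filter can be traced on a tiny example. This plain-Python sketch pairs each feature record with the query row for the same card, then keeps only records that are neither later than that row's `joindate` nor more than the staleness window before it. All card numbers and event times are hypothetical, and `days_between` is an illustrative helper, not a Spark function.

```python
from datetime import datetime

FMT = "%Y-%m-%d %H:%M:%S"
ALLOWED_STALENESS_DAYS = 4

# Hypothetical feature records: (cc_num, event_time)
events = [
    ("1111", "2021-01-25 08:00:00"),  # too old for the query below
    ("1111", "2021-01-30 08:00:00"),  # valid
    ("1111", "2021-01-31 20:00:00"),  # later than the query time
    ("3333", "2021-01-31 01:00:00"),  # card is not in the query set
]
# Hypothetical entity rows: (cc_num, joindate)
queries = [("1111", "2021-01-31 10:32:47")]


def days_between(end, start):
    """Calendar days from start to end, ignoring time of day."""
    return (datetime.strptime(end, FMT).date() - datetime.strptime(start, FMT).date()).days


# Inner join on cc_num
joined_rows = [(cc, ev, jd) for cc, ev in events for qcc, jd in queries if cc == qcc]

# Per-entity filter: not in the future and not stale, relative to this row's joindate
kept = [
    (cc, ev, jd) for cc, ev, jd in joined_rows
    if ev <= jd and days_between(jd, ev) <= ALLOWED_STALENESS_DAYS
]
print(len(joined_rows), kept)
```

The earlier coarse filter used one global window for all cards; this step applies the same two conditions again, but against each card's own query time.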


In [31]:
drop_future_and_stale.show(5)

+------+-----------+------------------+------------------+------------------+------------------+-------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+-------------------+
|amount|fraud_label|num_trans_last_60m|  avg_amt_last_60m|        amt_ratio1|        amt_ratio2|        count_ratio|                 tid|          cc_num|num_trans_last_1d|avg_amt_last_1d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|           joindate|
+------+-----------+------------------+------------------+------------------+------------------+-------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+-------------------+
|373.86|          0|                 1|            373.

In [32]:
# Group by cc_num and take the latest record
take_latest = (
    drop_future_and_stale.rdd.map(lambda x: (x.cc_num, x))  # to RDD with KVPs so we can use efficient reduceByKey
    .reduceByKey(
        # Latest event_time wins; API invocation time is the tie-breaker
        lambda x, y: x
        if (x.event_time, x.api_invocation_time) >= (y.event_time, y.api_invocation_time)
        else y
    )
    .values()  # drop keys
)
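
The reduce step can be checked without a cluster. Below is a minimal plain-Python sketch of a reduce-by-key that keeps the newest record per card and uses the API invocation time to break ties on event time. The `Record` tuple, the `newer` and `reduce_by_key` helpers, and the sample values are assumptions made for illustration; they only imitate the shape of the Spark rows.

```python
from collections import namedtuple

Record = namedtuple("Record", ["cc_num", "event_time", "api_invocation_time", "amount"])

# Hypothetical records; the first two share an event_time
records = [
    Record("1111", "2021-01-31 09:00:00", "2021-04-15 12:33:30", 10.0),
    Record("1111", "2021-01-31 09:00:00", "2021-04-15 12:40:00", 12.0),
    Record("1111", "2021-01-30 09:00:00", "2021-04-15 12:50:00", 8.0),
    Record("2222", "2021-01-31 07:00:00", "2021-04-15 12:33:30", 5.0),
]


def newer(x, y):
    """Pick the record with the later (event_time, api_invocation_time) pair."""
    x_key = (x.event_time, x.api_invocation_time)
    y_key = (y.event_time, y.api_invocation_time)
    return x if x_key >= y_key else y


def reduce_by_key(items, key_fn, reducer):
    """Fold items that share a key into one value, like a local reduceByKey."""
    reduced = {}
    for item in items:
        key = key_fn(item)
        reduced[key] = reducer(reduced[key], item) if key in reduced else item
    return reduced


latest = reduce_by_key(records, lambda r: r.cc_num, newer)
print({cc: rec.amount for cc, rec in latest.items()})
```

Tuple comparison checks `event_time` first and only falls through to `api_invocation_time` when the event times are equal, which matches the ordering used in the Studio query.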

In [33]:
# Convert to DF
latest_df = take_latest.toDF(drop_future_and_stale.schema)

In [34]:
# Drop Feature Store bookkeeping and partition columns that are not needed in the result
columns_to_drop = ["write_time", "is_deleted", "year", "month", "day", "hour", "api_invocation_time"]
selected = latest_df.drop(*columns_to_drop)

In [35]:
# Show query result
selected.show()

# To save the query result to S3:
# OUTPUT_PATH = f"s3a://{BUCKET}/{BASE_PREFIX}/test_query_output"
# selected.write.parquet(OUTPUT_PATH, mode="overwrite")

+------+-----------+------------------+------------------+-------------------+-------------------+-------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+-------------------+
|amount|fraud_label|num_trans_last_60m|  avg_amt_last_60m|         amt_ratio1|         amt_ratio2|        count_ratio|                 tid|          cc_num|num_trans_last_1d|avg_amt_last_1d|         event_time|   trans_time|           joindate|
+------+-----------+------------------+------------------+-------------------+-------------------+-------------------+--------------------+----------------+-----------------+---------------+-------------------+-------------+-------------------+
| 98.47|          0|                 1|             98.47| 0.0850845545942664| 0.0850845545942664|0.06666666666666667|4969b12d9509be06f...|4007144070776605|               15|        1157.32|2021-01-31 21:16:21|1.618490022E9|2021-01-31 21:54:08|
|353.62|          0|