### Use SageMaker Feature Store and Apache Spark to generate point-in-time queries to implement Time Travel
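This notebook builds a set of Spark DataFrames and queries that demonstrate a pattern for "Time Travel" (point-in-time) queries against the SageMaker Feature Store offline store: for each (entity, timestamp) pair, retrieve the latest feature values that were available at that time.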

In [None]:
!pip install Faker

In [5]:
# Faker
from faker import Faker

# Import pyspark and build Spark session

from pyspark.sql import SparkSession
from pyspark.sql.functions import datediff
from pyspark.sql.functions import lit
from pyspark.sql.functions import col
from pyspark.sql.functions import max as sql_max
from pyspark.sql.functions import min as sql_min
from pyspark.sql.types import StringType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType

from pyspark import SparkContext, SparkConf
import sagemaker_pyspark
import datetime

# Configure Spark to use the SageMaker Spark dependency jars: the driver
# classpath is a single ":"-separated string of jar locations.
sagemaker_jars = sagemaker_pyspark.classpath_jars()
classpath = ":".join(sagemaker_jars)

In [6]:
# Build (or reuse) the Spark session with the SageMaker jars on the driver classpath.
session_builder = SparkSession.builder.config("spark.driver.extraClassPath", classpath)
spark = session_builder.getOrCreate()

In [7]:
# Keep a handle on the underlying SparkContext and report the Spark version.
sc = spark.sparkContext
spark_version = sc.version
print(spark_version)

2.3.4


In [8]:
# Seed Faker so the synthetic query timestamps generated later are reproducible.
SEED = 123456
faker = Faker()
# NOTE(review): seed_instance() below re-seeds the instance; depending on the Faker
# version it may supersede this seed_locale() call — confirm which seed is effective.
faker.seed_locale('en_US', 0)
faker.seed_instance(SEED)

In [9]:
import sagemaker

# Resolve the session's default S3 bucket; the S3 paths in the following cells are built from it.
sagemaker_session = sagemaker.Session()
BUCKET = sagemaker_session.default_bucket()
print(BUCKET)

sagemaker-us-east-1-572539092864


In [10]:
import os

# Root prefix of the demo data in the session bucket.
BASE_PREFIX = "sagemaker-featurestore-demo"
OFFLINE_STORE_BASE_URI = f's3://{BUCKET}/{BASE_PREFIX}'

print(OFFLINE_STORE_BASE_URI)

# S3 keys always use "/" as separator; os.path.join would produce "\" on Windows,
# so the key is built explicitly.
PREFIX = f"{BASE_PREFIX}/aggregated"
print(PREFIX)

# Same location under two schemes: s3:// for the SageMaker SDK, s3a:// for Spark's
# Hadoop S3 connector.
# NOTE: despite its name, FEATURE_STORE_PATH_PARQ points at the aggregated CSV output
# (it is read with spark.read...csv further down).
FEATURE_STORE_PATH_S3 = f"s3://{BUCKET}/{PREFIX}/"
FEATURE_STORE_PATH_PARQ = f"s3a://{BUCKET}/{PREFIX}/"

s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo
sagemaker-featurestore-demo/aggregated


In [11]:
from sagemaker.s3 import S3Downloader

# List every object under the aggregated prefix (e.g. Spark's _SUCCESS marker and part files).
file_list = S3Downloader.list(FEATURE_STORE_PATH_S3)

print(f'Using S3 path: {FEATURE_STORE_PATH_S3}')
found_listing = "\n".join(file_list)
print(f"Found files: \n{found_listing}")

Using S3 path: s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/
Found files: 
s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/_SUCCESS
s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/part-00000-0e62ddd1-2489-452b-88db-d29726847041-c000.csv


In [None]:
# Read and show feature store DF
#feature_store_df = spark.read.parquet(FEATURE_STORE_PATH_PARQ)

# copy CSV file from above
##partfile = "part-00000-736d45d1-ee46-4745-b88a-7ba9460b239c-c000.csv"

##FEATURE_STORE_PATH_CSV = f"s3a://{BUCKET}/{PREFIX}/{partfile}"

In [12]:
# Load the aggregated transactions CSV (first line is the header); without schema
# inference every column is read as a string.
transactions_df = spark.read.csv(FEATURE_STORE_PATH_PARQ, header=True)

In [13]:
# Inspect the raw transactions schema; the row count is the cell's displayed output.
transactions_df.printSchema()
transactions_df.count()

root
 |-- tid: string (nullable = true)
 |-- event_time: string (nullable = true)
 |-- cc_num: string (nullable = true)
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- num_trans_last_1d: string (nullable = true)
 |-- avg_amt_last_1d: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)



500001

In [14]:
# Preview the first few raw transactions.
transactions_df.show(n=5)

+--------------------+--------------------+----------------+------+-----------+------------------+------------------+-----------------+------------------+-------------------+-------------------+------------------+
|                 tid|          event_time|          cc_num|amount|fraud_label|num_trans_last_60m|  avg_amt_last_60m|num_trans_last_1d|   avg_amt_last_1d|         amt_ratio1|         amt_ratio2|       count_ratio|
+--------------------+--------------------+----------------+------+-----------+------------------+------------------+-----------------+------------------+-------------------+-------------------+------------------+
|1df1e7c3a3d547646...|2021-01-01T00:08:...|4028853934607849| 18.68|          0|                 1|             18.68|                1|             18.68|                1.0|                1.0|               1.0|
|ef698627d0cf96b7a...|2021-01-01T01:04:...|4028853934607849|  1.09|          0|                 2|             9.885|                2|         

In [15]:
# Offline-store location of the aggregated feature group. It is built from BUCKET and
# BASE_PREFIX instead of a hard-coded bucket, so it follows the session's default bucket.
# The account, region and feature-group folder are the values this demo was run with;
# change them to point at a different feature group.
FG_ACCOUNT_ID = "572539092864"
FG_REGION = "us-east-1"
FG_OFFLINE_FOLDER = "cc-agg-batch-fg-1618437897"
OFFLINE_STORE_URI = (
    f"s3a://{BUCKET}/{BASE_PREFIX}/{FG_ACCOUNT_ID}/sagemaker/{FG_REGION}"
    f"/offline-store/{FG_OFFLINE_FOLDER}/data/"
)
# To read a single hourly partition instead, append e.g. "year=2021/month=04/day=14/hour=14/".


In [16]:
# Load the offline-store Parquet files under OFFLINE_STORE_URI (year/month/day/hour
# partitions become columns) and inspect schema and row count.
feature_store_df = spark.read.format("parquet").load(OFFLINE_STORE_URI)
feature_store_df.printSchema()
feature_store_df.count()

root
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)



1000

In [17]:
# Preview the first few offline-store records.
feature_store_df.show(n=5)

+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
|                 tid|          cc_num|num_trans_last_1d|avg_amt_last_1d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|
+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
|ceb3e268f8aeed354...|4047269399322294|               15|         555.18|2021-01-31 21:48:32|1.618438279E9|2021-04-14 22:16:...|2021-04-14 22:11:19|     false|2021|    4| 14|  22|
|66f97b675966d60dd...|4387164162852064|               21|         267.87|2021-01-31 22:26:40|1.618438279E9|2021-04-14 22:16:...|2021-04-14 22:11:19|     false|2021|    4| 14|  22|
|412a567f353a83a2f...|4447228755741220|               14|         772.38|2021-01-31 23:39:09|1.61843

In [19]:
# Join the raw transactions table to the aggregate feature table on transaction id (tid).
# Columns present in both frames are dropped from the transactions side, so the
# feature-store versions (typed long/double instead of string) are kept.
combined_df = transactions_df.join(feature_store_df, transactions_df.tid == feature_store_df.tid, "inner")
for shared_col in ("tid", "cc_num", "num_trans_last_1d", "avg_amt_last_1d", "event_time"):
    combined_df = combined_df.drop(transactions_df[shared_col])

combined_df.printSchema()
combined_df.count()

root
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)



1000

In [20]:
# Register the joined frame as the SQL temp view "features", queried with spark.sql below.
combined_df.createOrReplaceTempView("features")

In [21]:
import calendar

# Point-in-time cutoff for the SQL query below, plus the same instant as Unix epoch
# seconds (timegm interprets the naive datetime as UTC).
cutoff_time = datetime.datetime(2021, 2, 1)
cutoff_cal = calendar.timegm(cutoff_time.utctimetuple())

print('Cutoff time: {}'.format(cutoff_time))
print('Cutoff cal:  {}'.format(cutoff_cal))

Cutoff time: 2021-02-01 00:00:00
Cutoff cal:  1612137600


In [22]:
# Select the label and feature columns of every record whose event_time is at or before the cutoff.
# event_time is a string column, so this is a lexicographic comparison. That is correct
# here because both sides use the 'YYYY-MM-DD HH:MM:SS' layout (str(cutoff_time) prints it).
# NOTE(review): is_deleted is not selected or filtered here; deleted records are not excluded.
# Earlier variant of this query (different column names), kept for reference:
# events = spark.sql("SELECT transactionid, transactionamt, transactiondt, write_time FROM features WHERE write_time < '" + str(cutoff_time) + "' ORDER BY transactiondt")
events = spark.sql("SELECT tid, event_time, trans_time, api_invocation_time, cc_num, fraud_label, amount, num_trans_last_1d, avg_amt_last_1d \
                    FROM features WHERE event_time <= '" + str(cutoff_time) + "' ORDER BY event_time")

print ("Count: " + str(events.count()))

Count: 1000


In [23]:
# Preview the earliest events before the cutoff.
events.show(n=5)

+--------------------+-------------------+-------------+-------------------+----------------+-----------+------+-----------------+---------------+
|                 tid|         event_time|   trans_time|api_invocation_time|          cc_num|fraud_label|amount|num_trans_last_1d|avg_amt_last_1d|
+--------------------+-------------------+-------------+-------------------+----------------+-----------+------+-----------------+---------------+
|14d27169f9dfa1123...|2021-01-31 12:47:00|1.618438289E9|2021-04-14 22:11:29|4841758566240493|          0| 39.17|               19|        1067.15|
|44195c898b81a0382...|2021-01-31 12:53:39| 1.61843828E9|2021-04-14 22:11:19|4112052632105688|          0|  47.8|               16|        1232.35|
|8acb7c348354cd367...|2021-01-31 13:17:07|1.618438288E9|2021-04-14 22:11:28|4115181035084583|          0| 14.48|               23|         270.36|
|eab8602dce36b70c8...|2021-01-31 14:31:35| 1.61843829E9|2021-04-14 22:11:30|4007144070776605|          0|  1.71|      

### Sample Time Travel query from Studio

##### Simplified query for datasets up to 100 GB

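```sql
SELECT *
FROM
    (SELECT *,
         row_number()
        OVER (PARTITION BY <RecordIdentifierFeatureName>
    ORDER BY  EventTime DESC, Api_Invocation_Time DESC, write_time DESC) AS row_number
    FROM sagemaker_featurestore.identity-feature-group-03-20-32-44-1614803787
    WHERE EventTime <= timestamp '<timestamp>')
    -- replace timestamp '<timestamp>' with just <timestamp> if EventTimeFeature is of type fractional
WHERE row_number = 1 AND
NOT is_deleted
```

The window is partitioned by the feature group's record identifier, not by the event time. That way it ranks all versions of the same record, and `row_number = 1` keeps only the latest version written at or before `<timestamp>`. Records whose latest version is a delete are then removed by `NOT is_deleted`.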

In [24]:
import random

NUM_RANDOM_SAMPLES = 100

# Distinct card numbers to draw query entities from.
# - De-duplicating on the executors avoids collecting every transaction row to the driver.
# - With distinct values, the sample below cannot pick the same card twice.
# - Sorting gives a stable order, so a seeded sample is reproducible across runs.
cc_num_list = sorted(
    row.cc_num for row in transactions_df.select("cc_num").distinct().collect()
)

In [25]:
# Draw the query entities with a dedicated RNG seeded from SEED, so the sample (and every
# downstream result) is reproducible. The module-level `random` state is left untouched.
cc_num_sample = random.Random(SEED).sample(cc_num_list, NUM_RANDOM_SAMPLES)
print(cc_num_sample)

['4366675155603615', '4566212050718302', '4230880984727064', '4526580585599952', '4563254037187733', '4240943081010350', '4783713328425247', '4184024195705504', '4039951234037985', '4390047232225939', '4609751074905238', '4656683571012535', '4730351302648825', '4219457262138911', '4159574586922804', '4897510052779978', '4775216567122137', '4411504524619489', '4086324433357673', '4948841722803552', '4408878958981127', '4119215696686652', '4372082122888803', '4823810986511227', '4893308344742860', '4702407853825297', '4691298379888791', '4538154569785056', '4125580628518162', '4943697132856191', '4863832984712996', '4303127491121668', '4236603585689879', '4994055031069549', '4574116219417534', '4025129427215983', '4544141497443538', '4422842854172126', '4033160065800364', '4566129671485170', '4459780215953485', '4937809707561266', '4116370669391070', '4211230960431815', '4557873358799582', '4916419000645975', '4199439111597863', '4108484702041863', '4343189075029008', '4709600761304131',

In [33]:
# Pair each sampled card number with a fake query timestamp drawn between start and end.
TS_FORMAT = '%Y-%m-%d %H:%M:%S'
start = datetime.datetime.strptime('2021-01-31 12:00:00', TS_FORMAT)
end = datetime.datetime.strptime('2021-01-31 23:59:59', TS_FORMAT)


def _fake_query_time():
    """Return one Faker timestamp between ``start`` and ``end``, formatted with TS_FORMAT."""
    return faker.date_time_between(start_date=start, end_date=end, tzinfo=None).strftime(TS_FORMAT)


samples = []
for idx in range(NUM_RANDOM_SAMPLES):
    query_time = _fake_query_time()
    samples.append([cc_num_sample[idx], query_time])

print(samples)

[['4366675155603615', '2021-01-31 21:35:49'], ['4566212050718302', '2021-01-31 16:35:20'], ['4230880984727064', '2021-01-31 22:22:56'], ['4526580585599952', '2021-01-31 23:34:29'], ['4563254037187733', '2021-01-31 23:21:36'], ['4240943081010350', '2021-01-31 18:50:58'], ['4783713328425247', '2021-01-31 13:12:02'], ['4184024195705504', '2021-01-31 19:33:34'], ['4039951234037985', '2021-01-31 19:16:16'], ['4390047232225939', '2021-01-31 17:15:27'], ['4609751074905238', '2021-01-31 15:12:30'], ['4656683571012535', '2021-01-31 21:23:50'], ['4730351302648825', '2021-01-31 23:46:43'], ['4219457262138911', '2021-01-31 23:25:54'], ['4159574586922804', '2021-01-31 18:50:00'], ['4897510052779978', '2021-01-31 19:43:57'], ['4775216567122137', '2021-01-31 12:49:06'], ['4411504524619489', '2021-01-31 14:38:17'], ['4086324433357673', '2021-01-31 19:02:32'], ['4948841722803552', '2021-01-31 14:27:53'], ['4408878958981127', '2021-01-31 19:30:28'], ['4119215696686652', '2021-01-31 14:08:32'], ['4372082

In [34]:
# Schema of the entity (query) DataFrame. Each row is one lookup: a card number plus the
# point in time (joindate) at which its features should be retrieved.
query_df_schema = StructType([
    StructField(field_name, StringType(), False)
    for field_name in ('cc_num', 'joindate')
])

In [35]:
# Build the entity DataFrame from the sampled [cc_num, joindate] pairs and preview it.
query_df = spark.createDataFrame(data=samples, schema=query_df_schema)
query_df.show()

+----------------+-------------------+
|          cc_num|           joindate|
+----------------+-------------------+
|4366675155603615|2021-01-31 21:35:49|
|4566212050718302|2021-01-31 16:35:20|
|4230880984727064|2021-01-31 22:22:56|
|4526580585599952|2021-01-31 23:34:29|
|4563254037187733|2021-01-31 23:21:36|
|4240943081010350|2021-01-31 18:50:58|
|4783713328425247|2021-01-31 13:12:02|
|4184024195705504|2021-01-31 19:33:34|
|4039951234037985|2021-01-31 19:16:16|
|4390047232225939|2021-01-31 17:15:27|
|4609751074905238|2021-01-31 15:12:30|
|4656683571012535|2021-01-31 21:23:50|
|4730351302648825|2021-01-31 23:46:43|
|4219457262138911|2021-01-31 23:25:54|
|4159574586922804|2021-01-31 18:50:00|
|4897510052779978|2021-01-31 19:43:57|
|4775216567122137|2021-01-31 12:49:06|
|4411504524619489|2021-01-31 14:38:17|
|4086324433357673|2021-01-31 19:02:32|
|4948841722803552|2021-01-31 14:27:53|
+----------------+-------------------+
only showing top 20 rows



In [None]:
# entity data frame

# add multiple instances for same cc_num
#query_df = spark.createDataFrame([
#      # instead of tid we use cc_num
#      [cc_num_sample[0], "2021-01-31T18:00:00Z"], 
#      [cc_num_sample[1], "2021-01-31T14:00:00Z"],
#      [cc_num_sample[2], "2021-01-31T12:00:00Z"],
#      [cc_num_sample[3], "2021-01-31T17:00:00Z"],
#      [cc_num_sample[4], "2021-01-31T13:00:00Z"]
#    ],
#    query_df_schema)

# To instead load this query df from s3:
# QUERY_PATH = f"s3://{BUCKET}/{PREFIX}/test_query.parquet"
# query_df = spark.read.parquet(QUERY_PATH)

# TEST, will it still run ok with 5,000 rows in the entity frame
# try pandas sample() 

#query_df.show()

In [36]:
# Earliest and latest query times in the entity frame, computed in a single aggregation
# pass. They bound the time window used to pre-filter the feature records.
minmax_time = query_df.agg(sql_min(col("joindate")), sql_max(col("joindate"))).collect()
print(minmax_time)

[Row(min(joindate)='2021-01-31 12:30:06', max(joindate)='2021-01-31 23:58:07')]


In [37]:
# The single aggregate Row holds (min, max) in that order. Unpack it positionally instead of
# via Spark's generated column names ("min(joindate)" / "max(joindate)").
min_time, max_time = minmax_time[0]
print(f'min_time: {min_time}')
print(f'max_time: {max_time}')

min_time: 2021-01-31 12:30:06
max_time: 2021-01-31 23:58:07


In [38]:
# Baseline size of the candidate set before any time-window filtering.
print(f"Before filter, count: {events.count()}")

Before filter, count: 1000


In [None]:
# Filter deleted records out
#filtered = events.filter(~events.is_deleted)
#print("After count: " + str(filtered.count()))

In [39]:
# Coarse pre-filter over the whole query window. Keep records that are:
# - no later than the latest query time (max_time), and
# - no more than `allowed_staleness_days` days before the earliest query time (min_time).
# This is a performance optimization: it runs before the per-(cc_num, joindate) filtering,
# so fewer rows reach the join.
allowed_staleness_days = 4

filtered = events.filter(
    # datediff(end, start) returns end - start in whole days (both sides are cast to date).
    # An event's staleness relative to min_time is therefore datediff(min_time, event_time).
    # With the arguments reversed, the expression measures how far the event lies *after*
    # min_time, which is never > 4 for older events, so stale records were never dropped.
    (datediff(lit(min_time), events.event_time) <= allowed_staleness_days)
    & (events.event_time <= max_time)
)

print("After filter, count: " + str(filtered.count()))

After filter, count: 981


In [40]:
# Peek at which cards survived the window pre-filter.
filtered.select(col("cc_num")).show(n=20)

+----------------+
|          cc_num|
+----------------+
|4841758566240493|
|4112052632105688|
|4115181035084583|
|4007144070776605|
|4741003082139478|
|4661693857717130|
|4897249935555210|
|4213741526478791|
|4862530081141983|
|4056057388547395|
|4609270001711569|
|4349977314233413|
|4512227924204037|
|4221583135670743|
|4007365955573714|
|4647625239834026|
|4947638477574854|
|4585480040850249|
|4232420034941999|
|4766004007367160|
+----------------+
only showing top 20 rows



In [41]:
# Confirm the column types after the SQL projection (cc_num is long here, string in query_df).
filtered.printSchema()

root
 |-- tid: string (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- amount: string (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)



In [42]:
# Attach each query row (cc_num, joindate) to the candidate feature records of the same card.
# - The condition references `filtered` (the frame actually being joined), not combined_df.
# - The query's string cc_num is cast to long explicitly to match the feature column,
#   instead of relying on implicit string/numeric coercion.
# The query side's duplicate cc_num column is dropped after the join.
joined = filtered.join(
    query_df, filtered.cc_num == query_df.cc_num.cast("long"), "inner"
).drop(query_df.cc_num)
print("Joined count: " + str(joined.count()))

Joined count: 98


In [43]:
# Preview the joined query/feature rows (Spark's default of 20 rows).
joined.show(n=20)

+--------------------+-------------------+-------------+-------------------+----------------+-----------+-------+-----------------+---------------+-------------------+
|                 tid|         event_time|   trans_time|api_invocation_time|          cc_num|fraud_label| amount|num_trans_last_1d|avg_amt_last_1d|           joindate|
+--------------------+-------------------+-------------+-------------------+----------------+-----------+-------+-----------------+---------------+-------------------+
|d22910f4cf6475542...|2021-01-31 23:15:22|1.618438283E9|2021-04-14 22:11:23|4775216567122137|          0|9047.24|               16|        1475.59|2021-01-31 12:49:06|
|be2ae1d6c75fdc576...|2021-01-31 23:28:12|1.618438288E9|2021-04-14 22:11:28|4086324433357673|          0|4667.65|               13|         903.64|2021-01-31 19:02:32|
|3833b8ff2ffa65c14...|2021-01-31 23:16:21|1.618438289E9|2021-04-14 22:11:29|4691298379888791|          0| 353.52|               16|         203.15|2021-01-31 17

In [44]:
# Per-query point-in-time filter. Drop feature records recorded after the query's joindate
# (future leakage) or more than `allowed_staleness_days` days before it.
# All columns come from `joined` itself rather than from query_df/feature_store_df, so
# column resolution does not depend on those frames' lineage being carried through.
drop_future_and_stale = joined.filter(
    (joined.event_time <= joined.joindate)
    & (datediff(joined.joindate, joined.event_time) <= allowed_staleness_days)
)
print("After drop stale, count: " + str(drop_future_and_stale.count()))

After drop stale, count: 16


In [45]:
# Preview the rows that survive the per-query time filter (Spark's default of 20 rows).
drop_future_and_stale.show(n=20)

+--------------------+-------------------+-------------+-------------------+----------------+-----------+-------+-----------------+---------------+-------------------+
|                 tid|         event_time|   trans_time|api_invocation_time|          cc_num|fraud_label| amount|num_trans_last_1d|avg_amt_last_1d|           joindate|
+--------------------+-------------------+-------------+-------------------+----------------+-----------+-------+-----------------+---------------+-------------------+
|8b7d614872d30f5b5...|2021-01-31 22:09:46|1.618438279E9|2021-04-14 22:11:19|4655251061784855|          0|   1.91|               17|         999.39|2021-01-31 22:24:36|
|fd52042037115275d...|2021-01-31 19:28:44|1.618438284E9|2021-04-14 22:11:24|4526580585599952|          0|3108.82|               22|         844.45|2021-01-31 23:34:29|
|c3fcb8a3a2093a1da...|2021-01-31 21:01:00|1.618438282E9|2021-04-14 22:11:21|4366675155603615|          0|  13.39|               17|        1245.66|2021-01-31 21

In [46]:
# For every query row, keep only the latest feature record at or before its joindate.
# The key must identify the query (cc_num, joindate), not the transaction. tid identifies
# an individual transaction, so keying by it never merged different transactions of the
# same card, and every in-window record survived.
take_latest = (
    drop_future_and_stale.rdd
    .map(lambda row: ((row.cc_num, row.joindate), row))  # KV pairs for efficient reduceByKey
    .reduceByKey(
        # Latest event_time wins; api_invocation_time breaks ties between record versions.
        lambda x, y: x if ((x.event_time, x.api_invocation_time) >= (y.event_time, y.api_invocation_time)) else y
    )
    .values()  # drop keys
)

In [47]:
# Convert to DF
# Back to a DataFrame. reduceByKey keeps whole Rows, so the pre-reduction schema still applies.
latest_df = take_latest.toDF(schema=drop_future_and_stale.schema)

In [48]:
# Drop bookkeeping columns from the result. DataFrame.drop() ignores names that are absent,
# so the list can include offline-store metadata and partition columns. Of these names,
# only api_invocation_time is actually present in latest_df here.
columns_to_drop = [
    "write_time", "is_deleted",          # offline-store metadata
    "year", "month", "day", "hour",      # offline-store partition columns
    "query_time", "api_invocation_time",
]
selected = latest_df.drop(*columns_to_drop)

In [49]:
# Show the point-in-time query result (Spark's default of 20 rows).
selected.show(n=20)

# To save the query result to S3:
# OUTPUT_PATH = f"s3://{BUCKET}/{PREFIX}/test_query_output"
# selected.write.parquet(OUTPUT_PATH, mode="overwrite")

+--------------------+-------------------+-------------+----------------+-----------+-------+-----------------+---------------+-------------------+
|                 tid|         event_time|   trans_time|          cc_num|fraud_label| amount|num_trans_last_1d|avg_amt_last_1d|           joindate|
+--------------------+-------------------+-------------+----------------+-----------+-------+-----------------+---------------+-------------------+
|c55d78ffcc68cafc2...|2021-01-31 21:33:49|1.618438282E9|4303127491121668|          0| 898.27|               14|        1300.84|2021-01-31 23:01:44|
|c3fcb8a3a2093a1da...|2021-01-31 21:01:00|1.618438282E9|4366675155603615|          0|  13.39|               17|        1245.66|2021-01-31 21:35:49|
|ca0ef60e326d346f1...|2021-01-31 21:38:23|1.618438286E9|4730351302648825|          0|  31.38|               16|        1227.12|2021-01-31 23:46:43|
|fd52042037115275d...|2021-01-31 19:28:44|1.618438284E9|4526580585599952|          0|3108.82|               22| 