### Use SageMaker Feature Store and Apache Spark to generate point-in-time queries to implement Time Travel
The following notebook builds a set of DataFrames and Spark queries that provide a pattern for "Time Travel" queries against SageMaker Feature Store. We will demonstrate how to build point-in-time feature sets. Techniques include building Spark DataFrames and using cross-table joins and query filters to reduce the dataset. 

In [None]:
!pip install Faker

In [1]:
# Faker
from faker import Faker

# Import pyspark and build Spark session

from pyspark.sql import SparkSession
from pyspark.sql.functions import datediff
from pyspark.sql.functions import lit
from pyspark.sql.functions import col
from pyspark.sql.functions import max as sql_max
from pyspark.sql.functions import min as sql_min
from pyspark.sql.types import StringType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType

from pyspark import SparkContext, SparkConf
import sagemaker_pyspark
import datetime

# Configure Spark to use the SageMaker Spark dependency jars.
# The jar paths returned by sagemaker_pyspark.classpath_jars() are joined with
# ":" into one classpath string; the SparkSession builder passes it as
# spark.driver.extraClassPath.
classpath = ":".join(sagemaker_pyspark.classpath_jars())

# Uncomment to inspect the resolved classpath:
#print(f"Spark classpath_jars: {classpath}")

In [2]:
# Build (or reuse) the Spark session with the SageMaker jars on the driver classpath.
session_builder = SparkSession.builder.config("spark.driver.extraClassPath", classpath)
spark = session_builder.getOrCreate()


In [3]:
# Keep a handle on the SparkContext behind the session and print the Spark version in use.
sc = spark.sparkContext
print(sc.version)

2.3.4


In [4]:
# Fixed seed for the Faker instance that generates the fake query timestamps.
SEED = 123456
faker = Faker()
# Seed the en_US locale generator with 0, then seed the instance with SEED.
# NOTE(review): with a single-locale Faker the second call presumably overrides
# the first -- confirm whether seed_locale is still needed.
faker.seed_locale('en_US', 0)
faker.seed_instance(SEED)

In [5]:
import sagemaker

# Use the SageMaker session's default bucket for all S3 paths in this notebook.
BUCKET = sagemaker.Session().default_bucket()
print(BUCKET)

sagemaker-us-east-1-572539092864


In [6]:
import os

# S3 layout used by the demo: everything lives under BASE_PREFIX in BUCKET.
BASE_PREFIX = "sagemaker-featurestore-demo"
OFFLINE_STORE_BASE_URI = f's3://{BUCKET}/{BASE_PREFIX}'
print(OFFLINE_STORE_BASE_URI)

# S3 key prefixes always use "/" as the separator, so build the prefix
# explicitly instead of with os.path.join (which follows the local OS).
AGG_PREFIX = f"{BASE_PREFIX}/aggregated"
print(AGG_PREFIX)

# Same location under two schemes: s3:// for the SageMaker SDK, s3a:// for Spark.
# Despite its name, AGG_FEATURES_PATH_PARQUET points at files that are read as CSV.
AGG_FEATURES_PATH_S3 = f"s3://{BUCKET}/{AGG_PREFIX}/"
AGG_FEATURES_PATH_PARQUET = f"s3a://{BUCKET}/{AGG_PREFIX}/"

s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo
sagemaker-featurestore-demo/aggregated


In [7]:
from sagemaker.s3 import S3Downloader

# List the objects under the aggregated-features prefix and echo them.
file_list = S3Downloader.list(AGG_FEATURES_PATH_S3)

print(f'Using S3 path: {AGG_FEATURES_PATH_S3}')
found_listing = "\n".join(file_list)
print(f"Found files: \n{found_listing}")

Using S3 path: s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/
Found files: 
s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/_SUCCESS
s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/aggregated/part-00000-8516931f-40c6-4755-b0a2-79f831d24d50-c000.csv


In [8]:
# Read the aggregated transaction CSV files; the first line of each file is the header.
transactions_df = spark.read.csv(AGG_FEATURES_PATH_PARQUET, header=True)

In [9]:
# Print the inferred schema; the cell's last expression displays the row count.
transactions_df.printSchema()
transactions_df.count()

root
 |-- tid: string (nullable = true)
 |-- event_time: string (nullable = true)
 |-- cc_num: string (nullable = true)
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- num_trans_last_1d: string (nullable = true)
 |-- avg_amt_last_1d: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)



500001

In [10]:
# Preview the first five transaction rows.
transactions_df.show(5)

+--------------------+--------------------+----------------+------+-----------+------------------+----------------+-----------------+------------------+-------------------+-------------------+------------------+
|                 tid|          event_time|          cc_num|amount|fraud_label|num_trans_last_60m|avg_amt_last_60m|num_trans_last_1d|   avg_amt_last_1d|         amt_ratio1|         amt_ratio2|       count_ratio|
+--------------------+--------------------+----------------+------+-----------+------------------+----------------+-----------------+------------------+-------------------+-------------------+------------------+
|9f2457e8699918753...|2021-01-01T13:18:...|4006080197832643|  60.0|          0|                 1|            60.0|                1|              60.0|                1.0|                1.0|               1.0|
|ea6201607ef021510...|2021-01-02T06:43:...|4006080197832643|   1.8|          0|                 1|             1.8|                2|              30.9|

#### Use the SageMaker client to retrieve info about the Feature Group
We will use the `describe_feature_group` method to look up the S3 URI of the Offline Store data files.

In [11]:
from sagemaker import get_execution_role
import sagemaker
import boto3

# Execution role of this notebook and a boto3 SageMaker client, used to
# describe the feature group.
role = get_execution_role()
sm_client = boto3.Session().client(service_name='sagemaker')
# The Feature Store runtime client is not needed here; kept for reference.
#feature_store_client = boto3.client(service_name='sagemaker-featurestore-runtime')

In [12]:
# Name of the feature group whose offline store is queried in this notebook.
FEATURE_GROUP = 'cc-agg-batch-fg'

# Fetch the feature group description; the OfflineStoreConfig entry of this
# response is used to locate the offline store in S3.
response = sm_client.describe_feature_group(FeatureGroupName=FEATURE_GROUP)
print (response)

{'FeatureGroupArn': 'arn:aws:sagemaker:us-east-1:572539092864:feature-group/cc-agg-batch-fg', 'FeatureGroupName': 'cc-agg-batch-fg', 'RecordIdentifierFeatureName': 'cc_num', 'EventTimeFeatureName': 'trans_time', 'FeatureDefinitions': [{'FeatureName': 'tid', 'FeatureType': 'String'}, {'FeatureName': 'cc_num', 'FeatureType': 'Integral'}, {'FeatureName': 'num_trans_last_1d', 'FeatureType': 'Integral'}, {'FeatureName': 'avg_amt_last_1d', 'FeatureType': 'Fractional'}, {'FeatureName': 'event_time', 'FeatureType': 'String'}, {'FeatureName': 'trans_time', 'FeatureType': 'Fractional'}], 'CreationTime': datetime.datetime(2021, 4, 15, 12, 27, 20, 670000, tzinfo=tzlocal()), 'OnlineStoreConfig': {'EnableOnlineStore': True}, 'OfflineStoreConfig': {'S3StorageConfig': {'S3Uri': 's3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo'}, 'DisableGlueTableCreation': False, 'DataCatalogConfig': {'TableName': 'cc-agg-batch-fg-1618489640', 'Catalog': 'AwsDataCatalog', 'Database': 'sagemaker_featu

In [13]:
# Lookup S3 Location of Offline Store
#
# Everything needed to locate the data files is taken from the
# describe_feature_group response, so no account ID, region or table name is
# hard-coded here.

s3_storage_config = response['OfflineStoreConfig']['S3StorageConfig']
offline_store_base_s3uri = s3_storage_config['S3Uri']

resolved_s3uri = s3_storage_config.get('ResolvedOutputS3Uri')
if resolved_s3uri:
    # Prefer the data path reported by the service; normalise to one trailing "/".
    offline_store_full_s3uri = resolved_s3uri.rstrip('/') + '/'
else:
    # Fall back to <base>/<account>/sagemaker/<region>/offline-store/<table>/data/
    # ARN layout: arn:<partition>:sagemaker:<region>:<account>:feature-group/<name>
    arn_fields = response['FeatureGroupArn'].split(':')
    region, account_id = arn_fields[3], arn_fields[4]
    # NOTE(review): assumes the Glue table name equals the S3 directory name,
    # as it does for this feature group -- confirm for other groups.
    table_name = response['OfflineStoreConfig']['DataCatalogConfig']['TableName']
    offline_store_full_s3uri = (offline_store_base_s3uri.rstrip('/') +
        f'/{account_id}/sagemaker/{region}/offline-store' +
        f'/{table_name}' +
        '/data/')

print (offline_store_full_s3uri)

# The URI above keeps the s3:// scheme; Spark reads in this notebook use s3a://.

s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/572539092864/sagemaker/us-east-1/offline-store/cc-agg-batch-fg-1618489640/data/


In [14]:
# Load S3 location for Offline Store
# Derive the Spark-readable URI from the looked-up location instead of
# repeating it as a literal: only the scheme changes (s3:// -> s3a://).
OFFLINE_STORE_URI = offline_store_full_s3uri.replace("s3://", "s3a://", 1)
print (OFFLINE_STORE_URI)

s3a://sagemaker-us-east-1-572539092864/sagemaker-featurestore-demo/572539092864/sagemaker/us-east-1/offline-store/cc-agg-batch-fg-1618489640/data/


In [15]:
# Read Offline Store data
# The offline store files are read as Parquet; the schema is printed and the
# cell's last expression displays the row count.
feature_store_df = spark.read.parquet(OFFLINE_STORE_URI)
feature_store_df.printSchema()
feature_store_df.count()

root
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)



10001

In [16]:
# Preview the first five offline store records.
feature_store_df.show(5)

+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
|                 tid|          cc_num|num_trans_last_1d|avg_amt_last_1d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|
+--------------------+----------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
|8f3aac5ed6a43a632...|4006080197832643|               10|         164.94|2021-01-31 10:29:33|1.619115105E9|2021-04-22 18:16:...|2021-04-22 18:11:44|     false|2021|    4| 22|  18|
|137c5d260d041c794...|4008569092490794|               11|         327.33|2021-01-31 23:32:13|1.619115105E9|2021-04-22 18:16:...|2021-04-22 18:11:45|     false|2021|    4| 22|  18|
|4a4b1df6b77f84f4e...|4015965906982664|               21|        1014.54|2021-01-30 00:03:31|1.61911

In [17]:
# Join the raw transactions table to the aggregate feature table on tid.
# Columns present on both sides are kept from the feature store side only,
# so the transactions-side copies are dropped one by one after the join.

same_transaction = transactions_df.tid == feature_store_df.tid
enhanced_df = transactions_df.join(feature_store_df, same_transaction, "left_outer")
for shared_column in ("tid", "cc_num", "num_trans_last_1d", "avg_amt_last_1d", "event_time"):
    enhanced_df = enhanced_df.drop(transactions_df[shared_column])

In [None]:
# Print the joined schema; the cell's last expression displays the row count.
enhanced_df.printSchema()
enhanced_df.count()

root
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)



In [None]:
# Early cutoff to prune extra data
#import calendar

# create cutoff time
#cutoff_time = datetime.datetime(2021, 2, 1)
#cutoff_cal = calendar.timegm(cutoff_time.timetuple())

#print(f'Cutoff time: {cutoff_time}')
#print(f'Cutoff cal:  {cutoff_cal}')

### Sample Time Travel query from Studio

##### Simplified query for datasets up to 100 GB

SELECT *
FROM
    (SELECT *,
         row_number()
        OVER (PARTITION BY <RecordIdentifier>
    ORDER BY  EventTime desc, Api_Invocation_Time DESC, write_time DESC) AS row_number
    FROM sagemaker_featurestore.identity-feature-group-03-20-32-44-1614803787
    where EventTime <= timestamp '<timestamp>')
    -- replace timestamp '<timestamp>' with just <timestamp>  if EventTimeFeature is of type fractional
WHERE row_number = 1 and
NOT is_deleted

### Build Query Dataframe that spans the intended time window

In [None]:
# entity data frame

# add multiple instances for same cc_num
#query_df = spark.createDataFrame([
#      # instead of tid we use cc_num
#      [cc_num_sample[0], "2021-01-31T18:00:00Z"], 
#      [cc_num_sample[1], "2021-01-31T14:00:00Z"],
#      [cc_num_sample[2], "2021-01-31T12:00:00Z"],
#      [cc_num_sample[3], "2021-01-31T17:00:00Z"],
#      [cc_num_sample[4], "2021-01-31T13:00:00Z"]
#    ],
#    query_df_schema)

#query_df.show()

In [None]:
import random

# Number of rows to draw for the entity (query) dataframe
NUM_RANDOM_SAMPLES = 500

# Pull every cc_num value (one per transaction row) back to the driver.
cc_num_list = [row.cc_num for row in transactions_df.select("cc_num").collect()]

In [None]:
# Draw the entity sample from a generator seeded with SEED so that a
# Restart & Run All picks the same card numbers each time (the global
# `random` state is otherwise unseeded).
cc_num_sample = random.Random(SEED).sample(cc_num_list, NUM_RANDOM_SAMPLES)
#print(cc_num_sample)

In [None]:
# Build [cc_num, timestamp] pairs: each sampled card number gets a fake query
# time on 2021-01-31, truncated to the start of its hour.
start = datetime.datetime(2021, 1, 31, 0, 0, 0)
end = datetime.datetime(2021, 1, 31, 23, 0, 0)


def _fake_join_hour():
    """Return a Faker-generated time between start and end as 'YYYY-MM-DD HH:00:00'."""
    moment = faker.date_time_between(start_date=start, end_date=end, tzinfo=None)
    return moment.strftime('%Y-%m-%d %H:00:00')


samples = [[cc_num_sample[i], _fake_join_hour()] for i in range(NUM_RANDOM_SAMPLES)]

#print(samples)

In [None]:
# Schema of the query DF (e.g. entity_dataframe): two non-nullable string
# columns, the card number and the point in time to join at.
# (Original note: change to transactionid (tid).)

query_df_schema = StructType(
    [StructField(field_name, StringType(), False) for field_name in ('cc_num', 'joindate')]
)

In [None]:
# Create entity data frame
# Each [cc_num, joindate] pair in `samples` becomes one row of the query set.

query_df = spark.createDataFrame(samples, query_df_schema)
query_df.show()

In [None]:
# Performance improvement:
# get the earliest and latest joindate of the query set in a single aggregation.
# These two values bound the time window used to pre-filter the feature data.
earliest_joindate = sql_min("joindate")
latest_joindate = sql_max("joindate")

minmax_time = query_df.agg(earliest_joindate, latest_joindate).collect()
print(minmax_time)

In [None]:
# The aggregation returns a single Row; read both bounds from it by column name.
bounds_row = minmax_time[0]
min_time = bounds_row["min(joindate)"]
max_time = bounds_row["max(joindate)"]
print(f'min_time: {min_time}')
print(f'max_time: {max_time}')

In [None]:
# Row count of the joined data before any time-window filtering.
print("Before filter, count: " + str(enhanced_df.count()))

In [None]:
# Filter deleted records out
#events_window = events_window.filter(~events_window.is_deleted)
#print("After count: " + str(events_window.count()))

In [None]:
%%time

# Filter out records from after query max_time and before staleness window prior to the min_time
# This is a performance optimization; doing this prior to individual {shopper, query_time} filtering will be faster

# Choose a window of time before which we want to ignore records
allowed_staleness_days = 4

# Eliminate CC's (customers) who do NOT have any relevant records within a near timeframe (4 days)
filtered = enhanced_df.filter(
    # datediff ( enddate, startdate ) - returns days
    # we are actually removing items when event_time is MORE than 4 days before min_time (outside our buffer)
    (datediff(lit(min_time), enhanced_df.event_time) <= allowed_staleness_days) &
    (enhanced_df.event_time <= max_time)
)

In [None]:
# Schema and row count after the coarse time-window filter.
filtered.printSchema()
print("After filter, count: " + str(filtered.count()))

In [None]:
# Preview ten rows that survived the coarse filter.
filtered.show(10)

In [None]:
# Show the cc_num values of ten filtered rows.
filtered.select("cc_num").show(10)

In [None]:
# Inner-join the filtered feature rows with the query set on cc_num, keeping
# only the feature-side cc_num column.

same_card = filtered.cc_num == query_df.cc_num
joined = filtered.join(query_df, on=same_card, how="inner").drop(query_df.cc_num)
print(f"Joined count: {joined.count()}")

In [None]:
# Preview five joined rows (feature columns plus the query joindate).
joined.show(5)

In [None]:
# Per-row time window: for each joined row keep the event only if it is not
# later than that row's query time (joindate) and not older than the
# staleness window before it.
not_in_future = joined.event_time <= query_df.joindate
not_stale = datediff(query_df.joindate, joined.event_time) <= allowed_staleness_days

drop_future_and_stale = joined.filter(not_in_future & not_stale)
print(f"After drop stale, count: {drop_future_and_stale.count()}")

In [None]:
# Preview ten rows that fall inside their own query's time window.
drop_future_and_stale.show(10)

In [None]:
# Group by query row and take the latest record for each one.
#
# The query set may contain the same cc_num with different joindate values, so
# the grouping key is (cc_num, joindate); keying by cc_num alone would collapse
# those query rows into a single result.
def _recency_key(row):
    """Ordering used to pick the latest record: event time first, then API
    invocation time and write time as tie-breakers.

    Assumes all three fields are populated on rows that reach this step.
    """
    return (row.event_time, row.api_invocation_time, row.write_time)


take_latest = (
    drop_future_and_stale.rdd
    # to RDD with KVPs so we can use efficient reduceByKey
    .map(lambda row: ((row.cc_num, row.joindate), row))
    # keep whichever of the two rows is more recent; on a full tie keep the first
    .reduceByKey(lambda a, b: a if _recency_key(a) >= _recency_key(b) else b)
    .values()  # drop keys
)

In [None]:
# Convert to DF
# The reduced RDD holds rows of drop_future_and_stale, so its schema is reused.
latest_df = take_latest.toDF(drop_future_and_stale.schema)

In [None]:
# Remove the bookkeeping and partition columns from the final result.
columns_to_drop = [
    "write_time",
    "is_deleted",
    "year",
    "month",
    "day",
    "hour",
    "query_time",
    "api_invocation_time",
]
selected = latest_df.drop(*columns_to_drop)
print(f'Final selected count: {selected.count()}')

In [None]:
# Show query result

selected.show(10)

# To save query result to s3 (BASE_PREFIX is the prefix defined earlier in this
# notebook; NOTE(review): the other Spark paths here use the s3a:// scheme --
# confirm which scheme the write needs):
# OUTPUT_PATH = f"s3://{BUCKET}/{BASE_PREFIX}/test_query_output"
# selected.write.parquet(OUTPUT_PATH, mode="overwrite")