### Use SageMaker Feature Store and Apache Spark to generate point-in-time queries to implement Time Travel
The following notebook uses SageMaker Feature Store and Apache Spark to build a set of DataFrames and queries that provide a pattern for "Time Travel" capabilities. We demonstrate how to build a "point-in-time" feature set by starting with raw transactional data, joining that data with records from the Offline Store, and then building an "entity" dataset that defines the items we care about and the reference timestamp for each. Techniques include building Spark DataFrames, using outer and inner table joins, using query filters to prune items outside our timeframe, and finally using reduceByKey to reduce the final dataset.
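
Before the Spark version, it may help to see the core idea on a few in-memory rows. The following is a minimal plain-Python sketch, not the notebook's implementation: for one entity and one reference timestamp, it returns the most recent record whose `event_time` is not in the future. The function name `point_in_time_lookup` and the sample values are hypothetical.

```python
from datetime import datetime


def point_in_time_lookup(records, entity_id, as_of):
    """Return the latest record for entity_id with event_time <= as_of, or None."""
    candidates = [
        r for r in records
        if r["consumer_id"] == entity_id and r["event_time"] <= as_of
    ]
    return max(candidates, key=lambda r: r["event_time"], default=None)


# Hypothetical feature history for one consumer (values are made up).
history = [
    {"consumer_id": "A", "event_time": datetime(2021, 1, 29, 8), "num_trans_last_7d": 3},
    {"consumer_id": "A", "event_time": datetime(2021, 1, 31, 6), "num_trans_last_7d": 5},
    {"consumer_id": "A", "event_time": datetime(2021, 1, 31, 20), "num_trans_last_7d": 9},
]

row = point_in_time_lookup(history, "A", datetime(2021, 1, 31, 10))
print(row["num_trans_last_7d"])  # 5
```

The record from 20:00 is ignored because it did not exist yet at the 10:00 reference time; using it would leak future information into the feature set.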

#### Install Faker library to help generate timestamps within a given range

In [1]:
!pip install Faker



In [2]:
# Faker
from faker import Faker

# Import pyspark and build Spark session

from pyspark.sql import SparkSession
from pyspark.sql.functions import datediff
from pyspark.sql.functions import lit
from pyspark.sql.functions import col
from pyspark.sql.functions import max as sql_max
from pyspark.sql.functions import min as sql_min
from pyspark.sql.types import StringType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType

from pyspark import SparkContext, SparkConf
import sagemaker_pyspark
import datetime
import random

# Configure Spark to use the SageMaker Spark dependency jars
classpath = ":".join(sagemaker_pyspark.classpath_jars())


In [3]:
spark = (SparkSession
    .builder
    .config("spark.driver.extraClassPath", classpath)
    .getOrCreate())


In [4]:
sc = spark.sparkContext
print(sc.version)

2.3.4


In [5]:
SEED = 123456
faker = Faker()
faker.seed_locale('en_US', 0)
faker.seed_instance(SEED)

In [6]:
import sagemaker

BUCKET = sagemaker.Session().default_bucket()
print(BUCKET)

sagemaker-us-east-1-572539092864


In [7]:
import os

BASE_PREFIX = "sagemaker-featurestore-blog"

OFFLINE_STORE_BASE_URI = f's3://{BUCKET}/{BASE_PREFIX}'

AGG_PREFIX = os.path.join(BASE_PREFIX, 'aggregated')
print(f'S3 Aggregated Prefix: {AGG_PREFIX}')

AGG_FEATURES_PATH_S3 = f"s3://{BUCKET}/{AGG_PREFIX}/"
AGG_FEATURES_PATH_PARQUET = f"s3a://{BUCKET}/{AGG_PREFIX}/"

S3 Aggregated Prefix: sagemaker-featurestore-blog/aggregated


In [8]:
from sagemaker.s3 import S3Downloader

file_list = S3Downloader.list(AGG_FEATURES_PATH_S3)

print(f'Using S3 path: {AGG_FEATURES_PATH_S3}')
print("Found files: \n" + "\n".join(file_list))

Using S3 path: s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-blog/aggregated/
Found files: 
s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-blog/aggregated/_SUCCESS
s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-blog/aggregated/part-00000-c991f597-2032-42ab-ae57-974543cbcf59-c000.csv


#### Let's retrieve our credit card transaction data

In [9]:
transactions_df = spark.read.options(Header=True).csv(AGG_FEATURES_PATH_PARQUET)

In [10]:
transactions_df.printSchema()
transactions_df.count()

root
 |-- tid: string (nullable = true)
 |-- event_time: string (nullable = true)
 |-- cc_num: string (nullable = true)
 |-- consumer_id: string (nullable = true)
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- num_trans_last_7d: string (nullable = true)
 |-- avg_amt_last_7d: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)



500001

In [12]:
# Show 5 random rows from dataframe
show_fraction = float(5.0 / 500000.0)
print("Fraction: %f" % show_fraction)
print(transactions_df.sample(withReplacement=False, fraction=show_fraction, seed=3).collect())

Fraction: 0.000010
[Row(tid='a65d2367cff1400cd245c5db5361597f', event_time='2021-01-17T17:02:54.000Z', cc_num='4172096516686385', consumer_id='XBHV44170963561062', amount='5822.98', fraud_label='0', num_trans_last_60m='1', avg_amt_last_60m='5822.98', num_trans_last_7d='15', avg_amt_last_7d='1328.2613333333331', amt_ratio1='4.383911398961651', amt_ratio2='4.383911398961651', count_ratio='0.06666666666666667'), Row(tid='0cfce750d02f3c1275b7c36d8cf7b703', event_time='2021-01-02T02:34:34.000Z', cc_num='4675446179121786', consumer_id='MXMS11417993608776', amount='1791.58', fraud_label='0', num_trans_last_60m='1', avg_amt_last_60m='1791.58', num_trans_last_7d='3', avg_amt_last_7d='622.82', amt_ratio1='2.8765614463247804', amt_ratio2='2.8765614463247804', count_ratio='0.3333333333333333'), Row(tid='cd9b000579c24ac8846bf5f8cfac5269', event_time='2021-01-11T13:30:35.000Z', cc_num='4781272421257207', consumer_id='ZKFN80055633118034', amount='37.88', fraud_label='0', num_trans_last_60m='1', avg_a

#### Use SageMaker client to retrieve info about the Feature Group
We will use the `describe_feature_group` method to look up the S3 URI of the Offline Store data files.

In [13]:
from sagemaker import get_execution_role
import sagemaker
import boto3

role = get_execution_role()
sm_client = boto3.Session().client(service_name='sagemaker')


In [14]:
# Identify name of the Feature Group that contains aggregated features for our transaction data
FEATURE_GROUP = 'cc-agg-batch-fg'

feature_group_info = sm_client.describe_feature_group(FeatureGroupName=FEATURE_GROUP)
feature_group_info

{'FeatureGroupArn': 'arn:aws:sagemaker:us-east-1:572539092864:feature-group/cc-agg-batch-fg',
 'FeatureGroupName': 'cc-agg-batch-fg',
 'RecordIdentifierFeatureName': 'consumer_id',
 'EventTimeFeatureName': 'trans_time',
 'FeatureDefinitions': [{'FeatureName': 'tid', 'FeatureType': 'String'},
  {'FeatureName': 'cc_num', 'FeatureType': 'Integral'},
  {'FeatureName': 'consumer_id', 'FeatureType': 'String'},
  {'FeatureName': 'num_trans_last_7d', 'FeatureType': 'Integral'},
  {'FeatureName': 'avg_amt_last_7d', 'FeatureType': 'Fractional'},
  {'FeatureName': 'event_time', 'FeatureType': 'String'},
  {'FeatureName': 'trans_time', 'FeatureType': 'Fractional'}],
 'CreationTime': datetime.datetime(2021, 5, 4, 0, 2, 44, 97000, tzinfo=tzlocal()),
 'OnlineStoreConfig': {'EnableOnlineStore': True},
 'OfflineStoreConfig': {'S3StorageConfig': {'S3Uri': 's3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-blog',
   'ResolvedOutputS3Uri': 's3://sagemaker-us-east-1-572539092864/sagemaker-featur

In [15]:
# Lookup S3 Location of Offline Store

resolved_offline_store_s3_location = feature_group_info['OfflineStoreConfig']['S3StorageConfig']['ResolvedOutputS3Uri']

# Spark's Parquet file reader requires replacement of 's3' with 's3a'
offline_store_s3a_uri = resolved_offline_store_s3_location.replace("s3:", "s3a:")

print(offline_store_s3a_uri)

s3a://sagemaker-us-east-1-572539092864/sagemaker-featurestore-blog/572539092864/sagemaker/us-east-1/offline-store/cc-agg-batch-fg-1620086564/data
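
The `replace("s3:", "s3a:")` call above rewrites every occurrence of `s3:` in the string, which is fine for the URI shown but would also touch a key that happens to contain `s3:`. A slightly more defensive sketch, assuming a hypothetical helper named `to_s3a` and a made-up bucket name, rewrites only the scheme prefix:

```python
def to_s3a(uri):
    """Rewrite only the leading s3:// scheme to s3a://, leaving the key untouched."""
    if uri.startswith("s3a://"):
        return uri
    if not uri.startswith("s3://"):
        raise ValueError(f"expected an s3:// URI, got {uri!r}")
    return "s3a://" + uri[len("s3://"):]


# Hypothetical URI whose key contains the text 's3:'.
example = "s3://example-bucket/prefix/s3:odd-key/data"

print(to_s3a(example))                 # s3a://example-bucket/prefix/s3:odd-key/data
print(example.replace("s3:", "s3a:"))  # s3a://example-bucket/prefix/s3a:odd-key/data
```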


In [16]:
# Read Offline Store data
feature_store_df = spark.read.parquet(offline_store_s3a_uri)

In [17]:
feature_store_df.printSchema()
feature_store_df.count()

root
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- consumer_id: string (nullable = true)
 |-- num_trans_last_7d: long (nullable = true)
 |-- avg_amt_last_7d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)



10001

In [18]:
feature_store_df.show(5)

+--------------------+----------------+------------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
|                 tid|          cc_num|       consumer_id|num_trans_last_7d|avg_amt_last_7d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|
+--------------------+----------------+------------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
|c9fc14047042b8c77...|4819550644575521|AKJG26134916263543|               14|         108.39|2021-01-31 06:11:28|1.620086953E9|2021-05-04 00:14:...|2021-05-04 00:09:13|     false|2021|    5|  4|   0|
|3926227c69fcd439d...|4691951668579907|ASMP81111912057771|               14|         805.59|2021-01-31 18:44:35|1.620086953E9|2021-05-04 00:14:...|2021-05-04 00:09:13|     false|2021|    5|  4|   0|
|8d0d

#### Create an enhanced set of features by joining raw transaction data with aggregate features from the Offline Store

In [19]:
# Join the raw transactions table to the aggregate feature table

enhanced_df = (transactions_df.join(feature_store_df, transactions_df.tid == feature_store_df.tid, "left_outer")
    .drop(transactions_df.tid)
    .drop(transactions_df.cc_num)
    .drop(transactions_df.consumer_id)
    .drop(transactions_df.num_trans_last_7d)
    .drop(transactions_df.avg_amt_last_7d)
    .drop(transactions_df.event_time))

In [20]:
enhanced_df.printSchema()
enhanced_df.count()

root
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- consumer_id: string (nullable = true)
 |-- num_trans_last_7d: long (nullable = true)
 |-- avg_amt_last_7d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)



500001

### Sample Time Travel query from Studio

##### Simplified query for datasets up to 100 GB

SELECT *
FROM
    (SELECT *,
            row_number() OVER (
                PARTITION BY EventTime
                ORDER BY EventTime DESC, Api_Invocation_Time DESC, write_time DESC
            ) AS row_number
     FROM sagemaker_featurestore.identity-feature-group-03-20-32-44-1614803787
     -- replace timestamp '<timestamp>' with just <timestamp> if EventTimeFeature is of type fractional
     WHERE EventTime <= timestamp '<timestamp>')
WHERE row_number = 1
  AND NOT is_deleted
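
To make the ranking logic concrete, here is a small plain-Python sketch of what a `row_number() = 1` query of this shape computes. It assumes the intent is one row per record identifier, so it groups by `consumer_id`, whereas the sample query as written partitions by `EventTime`. The function name `latest_per_entity` and the sample rows are hypothetical.

```python
from datetime import datetime
from itertools import groupby
from operator import itemgetter


def latest_per_entity(rows, as_of, entity_key="consumer_id"):
    """Emulate row_number() = 1 per entity, then drop rows marked as deleted."""
    eligible = [r for r in rows if r["event_time"] <= as_of]
    # Newest first, using the same tie-breakers as the ORDER BY clause.
    eligible.sort(
        key=itemgetter("event_time", "api_invocation_time", "write_time"),
        reverse=True,
    )
    # Stable sort: rows stay newest-first within each entity.
    eligible.sort(key=itemgetter(entity_key))
    top_rows = [next(group) for _, group in groupby(eligible, key=itemgetter(entity_key))]
    return [r for r in top_rows if not r["is_deleted"]]


def t(hour):
    return datetime(2021, 1, 31, hour)


rows = [
    {"consumer_id": "A", "event_time": t(6), "api_invocation_time": t(7), "write_time": t(7), "is_deleted": False, "value": 1},
    {"consumer_id": "A", "event_time": t(6), "api_invocation_time": t(8), "write_time": t(8), "is_deleted": False, "value": 2},
    {"consumer_id": "A", "event_time": t(20), "api_invocation_time": t(20), "write_time": t(20), "is_deleted": False, "value": 3},
    {"consumer_id": "B", "event_time": t(5), "api_invocation_time": t(5), "write_time": t(5), "is_deleted": False, "value": 4},
    {"consumer_id": "B", "event_time": t(9), "api_invocation_time": t(9), "write_time": t(9), "is_deleted": True, "value": 5},
]

result = latest_per_entity(rows, as_of=t(10))
print([r["value"] for r in result])  # [2]
```

Entity A resolves to the row with the later API invocation time, and entity B disappears because its most recent row is a delete marker. Filtering on `is_deleted` after ranking, as the query does, is what prevents an older, undeleted row from resurfacing.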

### Build Entity Dataframe that spans the intended time window

In [21]:
# Num samples in entity dataframe
NUM_RANDOM_SAMPLES = 500

cid_list = transactions_df.rdd.map(lambda x: x.consumer_id).collect()

In [22]:
cid_sample = random.sample(cid_list, NUM_RANDOM_SAMPLES) 
print(len(cid_sample))

500


In [23]:
# Build list of tuples containing consumer IDs with faked timestamps within our time window
start = datetime.datetime.strptime('2021-01-31 00:00:00', '%Y-%m-%d %H:%M:%S')
end = datetime.datetime.strptime('2021-01-31 23:00:00', '%Y-%m-%d %H:%M:%S')

samples = list()
for r in range(NUM_RANDOM_SAMPLES):
    row = []
    fake_timestamp = faker.date_time_between(start_date=start, end_date=end, tzinfo=None).strftime('%Y-%m-%d %H:00:00')
    row.append(cid_sample[r])
    row.append(fake_timestamp)
    samples.append(row)
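
If Faker is not available, the same kind of hour-truncated timestamp can be drawn with the standard library. This is a sketch under that assumption, not a drop-in replacement: the helper name `random_hourly_timestamp` is hypothetical, and the values it produces will differ from Faker's for the same seed.

```python
import random
from datetime import datetime, timedelta


def random_hourly_timestamp(rng, start, end):
    """Pick a random moment in [start, end] and truncate it to the hour."""
    span_seconds = int((end - start).total_seconds())
    moment = start + timedelta(seconds=rng.randint(0, span_seconds))
    return moment.strftime('%Y-%m-%d %H:00:00')


rng = random.Random(123456)  # dedicated generator so the global seed is untouched
start = datetime(2021, 1, 31, 0, 0, 0)
end = datetime(2021, 1, 31, 23, 0, 0)

stamps = [random_hourly_timestamp(rng, start, end) for _ in range(5)]
print(stamps)
```

Because the minutes and seconds are truncated, a value of `23:00:00` can only appear when the draw lands exactly on the end of the range, so the largest timestamp in a sample will usually be `22:00:00`.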
    

In [24]:
# Create the schema for the Entity Dataframe
# (i.e. the dataframe that defines our set of consumer IDs and timestamps for our point-in-time queries)

entity_df_schema = StructType([
    StructField('consumer_id', StringType(), False),
    StructField('joindate', StringType(), False)
])

In [25]:
# Create entity data frame

entity_df = spark.createDataFrame(samples, entity_df_schema)
entity_df.show(10)

+------------------+-------------------+
|       consumer_id|           joindate|
+------------------+-------------------+
|EPQM42267411121685|2021-01-31 10:00:00|
|TKIN43838631624846|2021-01-31 01:00:00|
|AGKE78700875586528|2021-01-31 06:00:00|
|VANM94024305759394|2021-01-31 00:00:00|
|FFMZ51315248853491|2021-01-31 02:00:00|
|UCWI31015250828475|2021-01-31 01:00:00|
|XYFV50123039975887|2021-01-31 09:00:00|
|KOFP22214206152516|2021-01-31 01:00:00|
|KALI81416137453187|2021-01-31 04:00:00|
|YXFC62040921041604|2021-01-31 08:00:00|
+------------------+-------------------+
only showing top 10 rows



In [26]:
# Performance Improvement: 
# This first dataframe filter serves as a performance optimization to reduce the size of dataset
# We compute the overall min and max times for the initial filtering, in one pass

# entity_df used to define bounded time window
minmax_time = entity_df.agg(sql_min("joindate"), sql_max("joindate")).collect()
print(minmax_time)

[Row(min(joindate)='2021-01-31 00:00:00', max(joindate)='2021-01-31 22:00:00')]


In [27]:
min_time, max_time = minmax_time[0]["min(joindate)"], minmax_time[0]["max(joindate)"]
print(f'min_time: {min_time}')
print(f'max_time: {max_time}')

min_time: 2021-01-31 00:00:00
max_time: 2021-01-31 22:00:00
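
Note that `joindate` and `event_time` are both strings, so the min/max aggregation and the `<=` filters compare them lexicographically. That matches chronological order only when every value uses the same format. The raw CSV uses `2021-01-17T17:02:54.000Z` while the Offline Store column shows `2021-01-31 06:11:28`, so mixing the two would compare incorrectly. The sketch below illustrates the pitfall and one way to normalize; `parse_event_time` is a hypothetical helper, and treating both formats as the same time zone is an assumption.

```python
from datetime import datetime

# The two formats that appear in this notebook's outputs.
FORMATS = ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S.%fZ")


def parse_event_time(text):
    """Parse either timestamp format into a naive datetime."""
    for fmt in FORMATS:
        try:
            return datetime.strptime(text, fmt)
        except ValueError:
            continue
    raise ValueError(f"unrecognized timestamp format: {text!r}")


iso_style = "2021-01-31T05:00:00.000Z"  # earlier moment, CSV-style format
plain_style = "2021-01-31 06:00:00"     # later moment, Offline Store-style format

print(iso_style <= plain_style)                                      # False
print(parse_event_time(iso_style) <= parse_event_time(plain_style))  # True
```

The string comparison fails because `'T'` sorts after a space, even though the first timestamp is an hour earlier.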


In [28]:
print("Before filter, count: " + str(enhanced_df.count()))

Before filter, count: 500001


In [29]:
%%time

# Filter out records from after query max_time and before staleness window prior to the min_time
# NOTE: This is a performance optimization; doing this prior to individual {consumer_id, joindate} filtering will be faster

# Choose a "staleness" window of time before which we want to ignore records
allowed_staleness_days = 4

# Eliminate records that fall outside our overall time window.
# This window represents the {max_time - min_time} delta, plus our staleness window (4 days)

# Via the staleness check, we remove records whose event_time is MORE than 4 days before min_time
# Usage: datediff(enddate, startdate) - returns the number of days between the two dates

filtered = enhanced_df.filter(
    (enhanced_df.event_time <= max_time) & 
    (datediff(lit(min_time), enhanced_df.event_time) <= allowed_staleness_days)
)

CPU times: user 0 ns, sys: 2.35 ms, total: 2.35 ms
Wall time: 47.5 ms
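
Spark's `datediff` counts whole calendar days and ignores the time of day, which affects exactly where the staleness cutoff falls. The following plain-Python sketch mirrors the filter above on a few hand-picked timestamps; the local `datediff` and `within_window` functions are illustrative stand-ins, not Spark APIs.

```python
from datetime import datetime


def datediff(end, start):
    """Calendar-day difference, ignoring time of day (like Spark's datediff)."""
    return (end.date() - start.date()).days


def within_window(event_time, min_time, max_time, allowed_staleness_days):
    """Same predicate as the DataFrame filter: not in the future, not too stale."""
    return (
        event_time <= max_time
        and datediff(min_time, event_time) <= allowed_staleness_days
    )


min_time = datetime(2021, 1, 31, 0, 0, 0)
max_time = datetime(2021, 1, 31, 22, 0, 0)

print(within_window(datetime(2021, 1, 27, 0, 0, 1), min_time, max_time, 4))    # True
print(within_window(datetime(2021, 1, 26, 23, 59, 59), min_time, max_time, 4))  # False
print(within_window(datetime(2021, 1, 31, 23, 30, 0), min_time, max_time, 4))   # False
```

The first two events are only two seconds apart, but they fall on different calendar days, so one is kept and the other is dropped.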


In [30]:
filtered.printSchema()
print("After filter, count: " + str(filtered.count()))

root
 |-- amount: string (nullable = true)
 |-- fraud_label: string (nullable = true)
 |-- num_trans_last_60m: string (nullable = true)
 |-- avg_amt_last_60m: string (nullable = true)
 |-- amt_ratio1: string (nullable = true)
 |-- amt_ratio2: string (nullable = true)
 |-- count_ratio: string (nullable = true)
 |-- tid: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- consumer_id: string (nullable = true)
 |-- num_trans_last_7d: long (nullable = true)
 |-- avg_amt_last_7d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- trans_time: double (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)

After filter, count: 8807


In [31]:
filtered.show(5)

+-------+-----------+------------------+----------------+-------------------+-------------------+-------------------+--------------------+----------------+------------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
| amount|fraud_label|num_trans_last_60m|avg_amt_last_60m|         amt_ratio1|         amt_ratio2|        count_ratio|                 tid|          cc_num|       consumer_id|num_trans_last_7d|avg_amt_last_7d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|
+-------+-----------+------------------+----------------+-------------------+-------------------+-------------------+--------------------+----------------+------------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+
|  79.42|          0|                 1|           79.4

In [32]:
filtered.select("cc_num", "consumer_id").show(5)

+----------------+------------------+
|          cc_num|       consumer_id|
+----------------+------------------+
|4250798115116459|AGKN34752255584542|
|4527632357482179|AINP94926359368311|
|4819550644575521|AKJG26134916263543|
|4031641771363971|AOUZ04666346722161|
|4198816697198128|ARRP87120193280992|
+----------------+------------------+
only showing top 5 rows



In [33]:
# Join filtered dataframe with generated entity dataframe; drop duplicate consumer_id field

joined = filtered.join(entity_df, filtered.consumer_id == entity_df.consumer_id, "inner").drop(entity_df.consumer_id)
print("Joined count: " + str(joined.count()))

Joined count: 441


In [34]:
joined.show(5)

+------+-----------+------------------+----------------+--------------------+--------------------+-------------------+--------------------+----------------+------------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+-------------------+
|amount|fraud_label|num_trans_last_60m|avg_amt_last_60m|          amt_ratio1|          amt_ratio2|        count_ratio|                 tid|          cc_num|       consumer_id|num_trans_last_7d|avg_amt_last_7d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|           joindate|
+------+-----------+------------------+----------------+--------------------+--------------------+-------------------+--------------------+----------------+------------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+-------------

In [35]:
# Filter out data from after query time or before query time minus staleness window
# This query removes events outside the time window FOR the SPECIFIC consumer
drop_future_and_stale = joined.filter(
    (joined.event_time <= joined.joindate)
    & (datediff(joined.joindate, joined.event_time) <= allowed_staleness_days)
)

print("After drop stale, count: " + str(drop_future_and_stale.count()))

After drop stale, count: 200


In [36]:
drop_future_and_stale.show(5)

+-------+-----------+------------------+----------------+--------------------+--------------------+-------------------+--------------------+----------------+------------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+-------------------+
| amount|fraud_label|num_trans_last_60m|avg_amt_last_60m|          amt_ratio1|          amt_ratio2|        count_ratio|                 tid|          cc_num|       consumer_id|num_trans_last_7d|avg_amt_last_7d|         event_time|   trans_time|          write_time|api_invocation_time|is_deleted|year|month|day|hour|           joindate|
+-------+-----------+------------------+----------------+--------------------+--------------------+-------------------+--------------------+----------------+------------------+-----------------+---------------+-------------------+-------------+--------------------+-------------------+----------+----+-----+---+----+----------

In [37]:
# Use reduceByKey to group by consumer_id and keep most recent record
take_latest = (
    drop_future_and_stale.rdd.map(lambda x: (x.consumer_id, x)) 
    .reduceByKey(
        lambda x, y: x if ((x.event_time) >= (y.event_time)) else y
    )  #  We could have used api_invocation_time as tie-breaker
    .values()  # drop keys
)
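
The comment above mentions `api_invocation_time` as a possible tie-breaker. As a rough illustration of how that could look, here is a plain-Python sketch of the reduce step; `reduce_by_key` is a hypothetical stand-in for the RDD method, `newer` is a hypothetical reduce function, and the rows are made up.

```python
def reduce_by_key(pairs, func):
    """Plain-Python stand-in for RDD.reduceByKey: fold values that share a key."""
    reduced = {}
    for key, value in pairs:
        reduced[key] = func(reduced[key], value) if key in reduced else value
    return reduced


def newer(x, y):
    """Keep the row with the later (event_time, api_invocation_time) pair."""
    x_rank = (x["event_time"], x["api_invocation_time"])
    y_rank = (y["event_time"], y["api_invocation_time"])
    return x if x_rank >= y_rank else y


rows = [
    {"consumer_id": "A", "event_time": "2021-01-31 06:00:00",
     "api_invocation_time": "2021-05-04 00:09:13", "amount": 10.0},
    {"consumer_id": "A", "event_time": "2021-01-31 06:00:00",
     "api_invocation_time": "2021-05-04 00:12:00", "amount": 12.5},
    {"consumer_id": "B", "event_time": "2021-01-30 11:00:00",
     "api_invocation_time": "2021-05-04 00:09:13", "amount": 7.0},
]

latest = reduce_by_key(((r["consumer_id"], r) for r in rows), newer)
print(latest["A"]["amount"])  # 12.5
```

Comparing tuples makes the tie-break explicit: rows with equal `event_time` fall through to `api_invocation_time`. The same `newer` function could be passed to `reduceByKey` in the Spark cell, since it only reads fields by name and is associative.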


In [38]:
# Convert to DataFrame
latest_df = take_latest.toDF(drop_future_and_stale.schema)

In [39]:
# Drop extra columns
columns_to_drop = ["write_time", "is_deleted", "year", "month", "day", "hour", "api_invocation_time"]
final_df = latest_df.drop(*columns_to_drop)

print('Final count: ' + str(final_df.count()))

Final count: 200


In [40]:
# Show final query results

final_df.show(10)

# To save query result to S3 (Spark needs the 's3a' scheme):
# OUTPUT_PATH = f"s3a://{BUCKET}/{BASE_PREFIX}/test_query_output"
# final_df.write.parquet(OUTPUT_PATH, mode="overwrite")

+-------+-----------+------------------+----------------+--------------------+--------------------+--------------------+--------------------+----------------+------------------+-----------------+---------------+-------------------+-------------+-------------------+
| amount|fraud_label|num_trans_last_60m|avg_amt_last_60m|          amt_ratio1|          amt_ratio2|         count_ratio|                 tid|          cc_num|       consumer_id|num_trans_last_7d|avg_amt_last_7d|         event_time|   trans_time|           joindate|
+-------+-----------+------------------+----------------+--------------------+--------------------+--------------------+--------------------+----------------+------------------+-----------------+---------------+-------------------+-------------+-------------------+
|  20.85|          0|                 1|           20.85| 0.19102619700588197| 0.19102619700588197| 0.06666666666666667|04f21b0619064256a...|4651276579663766|IXUJ57796449588233|               15|       