### Use SageMaker Feature Store and Apache Spark to generate point-in-time queries to implement Time Travel
The following notebook uses SageMaker Feature Store and Apache Spark to build a set of DataFrames and queries that demonstrate a pattern for "Time Travel". We show how to build a "point-in-time" feature set by starting with raw transactional data, joining that data with records from the Offline Store, and then building an "entity" dataset that defines the items we care about and the reference timestamp for each. Techniques include building Spark DataFrames, using left outer and inner joins, using query filters to prune records outside our time window, and finally using `reduceByKey` to reduce the final dataset.

#### Install Faker library to help generate timestamps within a given range

In [None]:
!pip install Faker

In [None]:
# Faker
from faker import Faker

# Import pyspark and build Spark session

from pyspark.sql import SparkSession
from pyspark.sql.functions import datediff
from pyspark.sql.functions import lit
from pyspark.sql.functions import col
from pyspark.sql.functions import max as sql_max
from pyspark.sql.functions import min as sql_min
from pyspark.sql.types import StringType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType

from pyspark import SparkContext, SparkConf
import sagemaker_pyspark
import datetime
import random

# Configure Spark to use the SageMaker Spark dependency jars
classpath = ":".join(sagemaker_pyspark.classpath_jars())


In [None]:
spark = (SparkSession
    .builder
    .config("spark.driver.extraClassPath", classpath)
    .getOrCreate())


In [None]:
sc = spark.sparkContext
print(sc.version)

In [None]:
SEED = 123456
faker = Faker()
faker.seed_locale('en_US', 0)
faker.seed_instance(SEED)
# Also seed Python's random module so that random.sample() below is reproducible
random.seed(SEED)

In [None]:
import sagemaker

BUCKET = sagemaker.Session().default_bucket()
print(BUCKET)

In [None]:
import os

BASE_PREFIX = "sagemaker-featurestore-blog"

OFFLINE_STORE_BASE_URI = f's3://{BUCKET}/{BASE_PREFIX}'

AGG_PREFIX = os.path.join(BASE_PREFIX, 'aggregated')
print(f'S3 Aggregated Prefix: {AGG_PREFIX}')

AGG_FEATURES_PATH_S3 = f"s3://{BUCKET}/{AGG_PREFIX}/"
AGG_FEATURES_PATH_PARQUET = f"s3a://{BUCKET}/{AGG_PREFIX}/"

In [None]:
from sagemaker.s3 import S3Downloader

file_list = S3Downloader.list(AGG_FEATURES_PATH_S3)

print(f'Using S3 path: {AGG_FEATURES_PATH_S3}')
print("Found files: \n" + "\n".join(file_list))

#### Let's retrieve our credit card transaction data

In [None]:
transactions_df = spark.read.options(Header=True).csv(AGG_FEATURES_PATH_PARQUET)

In [None]:
transactions_df.printSchema()
transactions_df.count()

In [None]:
# Show about 5 random rows; sample() draws each row independently, so the row count is approximate
show_fraction = 5.0 / transactions_df.count()
print(f"Fraction: {show_fraction:f}")
transactions_df.sample(withReplacement=False, fraction=show_fraction, seed=3).show()

#### Use the SageMaker client to retrieve info about the Feature Group
We will use the `describe_feature_group` method to look up the S3 URI of the Offline Store data files.

In [None]:
from sagemaker import get_execution_role
import sagemaker
import boto3

role = get_execution_role()
sm_client = boto3.Session().client(service_name='sagemaker')


In [None]:
# Identify name of the Feature Group that contains aggregated features for our transaction data
FEATURE_GROUP = 'cc-agg-batch-fg'

feature_group_info = sm_client.describe_feature_group(FeatureGroupName=FEATURE_GROUP)
feature_group_info

In [None]:
# Lookup S3 Location of Offline Store

resolved_offline_store_s3_location = feature_group_info['OfflineStoreConfig']['S3StorageConfig']['ResolvedOutputS3Uri']

# Spark reads S3 here through the Hadoop S3A connector, which expects the 's3a://' scheme instead of 's3://'
offline_store_s3a_uri = resolved_offline_store_s3_location.replace("s3:", "s3a:")

print(offline_store_s3a_uri)
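The `str.replace` call above rewrites every `s3:` substring in the URI, not only the scheme. The following is a minimal sketch of a stricter helper that rewrites only a leading `s3://`, assuming a response dict shaped like the `describe_feature_group` output used above. `to_s3a_uri`, `offline_store_s3a_uri`, and the example URI are hypothetical names and values for illustration.

```python
def to_s3a_uri(s3_uri: str) -> str:
    """Rewrite a leading 's3://' scheme to 's3a://'; return other URIs unchanged."""
    prefix = "s3://"
    if s3_uri.startswith(prefix):
        return "s3a://" + s3_uri[len(prefix):]
    return s3_uri


def offline_store_s3a_uri(feature_group_info: dict) -> str:
    """Read ResolvedOutputS3Uri from a describe_feature_group-style response dict."""
    s3_config = feature_group_info["OfflineStoreConfig"]["S3StorageConfig"]
    return to_s3a_uri(s3_config["ResolvedOutputS3Uri"])


# Hypothetical response fragment; a real response contains many more keys
example_info = {
    "OfflineStoreConfig": {
        "S3StorageConfig": {
            "ResolvedOutputS3Uri": "s3://example-bucket/example-prefix/data"
        }
    }
}
print(offline_store_s3a_uri(example_info))  # s3a://example-bucket/example-prefix/data
```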

In [None]:
# Read Offline Store data
feature_store_df = spark.read.parquet(offline_store_s3a_uri)

In [None]:
feature_store_df.printSchema()
feature_store_df.count()

In [None]:
feature_store_df.show(5)

#### Create an enhanced set of features by joining raw transaction data with aggregate features from the Offline Store

In [None]:
# Join the raw transactions table to the aggregate feature table, then drop the transaction-side copies of the shared columns

enhanced_df = (transactions_df.join(feature_store_df, transactions_df.tid == feature_store_df.tid, "left_outer")
    .drop(transactions_df.tid)
    .drop(transactions_df.cc_num)
    .drop(transactions_df.consumer_id)
    .drop(transactions_df.num_trans_last_7d)
    .drop(transactions_df.avg_amt_last_7d)
    .drop(transactions_df.event_time))
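Because the cell above drops the transaction-side copies of `tid`, `consumer_id`, and `event_time`, a transaction with no matching Offline Store row ends up with nulls in those columns, and the later `event_time` filters discard it (a comparison against null is never true). The plain-Python model below is a sketch for intuition only, not Spark code; `left_join_prefer_right` and the toy rows are hypothetical.

```python
from typing import Dict, List


def left_join_prefer_right(left: List[dict], right: List[dict], key: str) -> List[dict]:
    """Left outer join on `key` that keeps the right side's copy of shared columns.

    This mirrors joining and then dropping the left-side duplicates: left rows
    with no match keep their left-only columns, but every right-side column
    (including `key` and any shared column) becomes None.
    """
    right_by_key: Dict[object, List[dict]] = {}
    for row in right:
        right_by_key.setdefault(row[key], []).append(row)
    right_cols = {c for row in right for c in row}

    joined = []
    for lrow in left:
        # Columns that exist only on the left survive unchanged
        base = {c: v for c, v in lrow.items() if c not in right_cols}
        matches = right_by_key.get(lrow[key], [])
        if matches:
            joined.extend({**base, **rrow} for rrow in matches)
        else:
            # Unmatched: every right-side column, including shared ones, is None
            joined.append({**base, **{c: None for c in right_cols}})
    return joined


# Hypothetical toy rows
transactions = [
    {"tid": "t1", "amount": 10.0, "event_time": "2021-01-30 10:00:00"},
    {"tid": "t2", "amount": 5.0, "event_time": "2021-01-30 11:00:00"},
]
features = [
    {"tid": "t1", "consumer_id": "c1", "event_time": "2021-01-30 10:00:00"},
]
for row in left_join_prefer_right(transactions, features, "tid"):
    print(row)
```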

In [None]:
enhanced_df.printSchema()
enhanced_df.count()

### Sample Time Travel query from Studio

Now that we have an enhanced DataFrame with all our transaction data, we can start building the time travel query. We begin by creating an entity DataFrame, which identifies the consumer_ids of interest, each paired with a timestamp (`joindate`) that serves as the cutoff time for that entity. We also define a staleness window, which prevents us from using records older than a limit that we choose.
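The core rule can be stated per entity: take the newest record at or before the entity's cutoff, but no older than the staleness window. Below is a brute-force plain-Python sketch of that rule, for intuition only. It uses an exact `timedelta` window, whereas the Spark cells that follow use `datediff`, which counts calendar days. `point_in_time_lookup` and the record fields and values are hypothetical.

```python
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple

FMT = "%Y-%m-%d %H:%M:%S"


def point_in_time_lookup(
    entities: List[Tuple[str, str]],
    records: List[dict],
    staleness: timedelta,
) -> Dict[Tuple[str, str], Optional[dict]]:
    """For each (consumer_id, cutoff) entity, return the newest record whose
    event_time is at or before the cutoff and within the staleness window,
    or None if no record qualifies."""
    result: Dict[Tuple[str, str], Optional[dict]] = {}
    for consumer_id, cutoff_str in entities:
        cutoff = datetime.strptime(cutoff_str, FMT)
        best, best_time = None, None
        for rec in records:
            if rec["consumer_id"] != consumer_id:
                continue
            t = datetime.strptime(rec["event_time"], FMT)
            # Skip "future" records (after the cutoff) and stale records
            if t > cutoff or cutoff - t > staleness:
                continue
            if best_time is None or t > best_time:
                best, best_time = rec, t
        result[(consumer_id, cutoff_str)] = best
    return result


# Hypothetical toy data
records = [
    {"consumer_id": "c1", "event_time": "2021-01-25 09:00:00", "avg_amt": 10.0},
    {"consumer_id": "c1", "event_time": "2021-01-29 09:00:00", "avg_amt": 20.0},
    {"consumer_id": "c1", "event_time": "2021-01-31 15:00:00", "avg_amt": 30.0},
    {"consumer_id": "c2", "event_time": "2021-01-20 09:00:00", "avg_amt": 40.0},
]
entities = [
    ("c1", "2021-01-31 12:00:00"),
    ("c1", "2021-01-31 16:00:00"),
    ("c2", "2021-01-31 12:00:00"),
]
for key, rec in point_in_time_lookup(entities, records, timedelta(days=4)).items():
    print(key, rec["avg_amt"] if rec else None)
```

The same consumer can appear with two cutoffs and receive two different answers, which is why the result is keyed by the `(consumer_id, cutoff)` pair rather than by consumer alone.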

In [None]:
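# Number of samples in the entity dataframe
NUM_RANDOM_SAMPLES = 500

# Collect distinct consumer IDs so that random.sample() cannot pick the same consumer twice
cid_list = transactions_df.select("consumer_id").distinct().rdd.map(lambda x: x.consumer_id).collect()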

In [None]:
cid_sample = random.sample(cid_list, NUM_RANDOM_SAMPLES) 
print(len(cid_sample))

In [None]:
# Build a list of (consumer_id, timestamp) tuples, with fake timestamps truncated to the hour within our time window
start = datetime.datetime.strptime('2021-01-31 00:00:00', '%Y-%m-%d %H:%M:%S')
end = datetime.datetime.strptime('2021-01-31 23:00:00', '%Y-%m-%d %H:%M:%S')

samples = []
for cid in cid_sample:
    fake_timestamp = faker.date_time_between(start_date=start, end_date=end, tzinfo=None).strftime('%Y-%m-%d %H:00:00')
    samples.append((cid, fake_timestamp))
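If Faker is not available, the Python standard library can produce the same kind of hour-truncated timestamps. This is a substitute technique, not what the notebook uses, and it will not reproduce Faker's values for the same seed; `random_hourly_timestamps` is a hypothetical helper.

```python
import random
from datetime import datetime, timedelta
from typing import List


def random_hourly_timestamps(n: int, start: datetime, end: datetime, seed: int) -> List[str]:
    """Draw n uniform timestamps in [start, end] and truncate each to the hour."""
    rng = random.Random(seed)  # private generator, so global random state is untouched
    span_seconds = int((end - start).total_seconds())
    stamps = []
    for _ in range(n):
        t = start + timedelta(seconds=rng.randint(0, span_seconds))
        stamps.append(t.strftime("%Y-%m-%d %H:00:00"))
    return stamps


start = datetime(2021, 1, 31, 0, 0, 0)
end = datetime(2021, 1, 31, 23, 0, 0)
print(random_hourly_timestamps(3, start, end, seed=123456))
```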
    

In [None]:
# Define the schema of the entity dataframe
# (i.e., the dataframe that defines our set of consumer IDs and reference timestamps for our point-in-time queries)

entity_df_schema = StructType([
    StructField('consumer_id', StringType(), False),
    StructField('joindate', StringType(), False)
])

In [None]:
# Create entity data frame

entity_df = spark.createDataFrame(samples, entity_df_schema)
entity_df.show(10)

In [None]:
# Performance optimization:
# Compute the overall min and max joindate in a single pass; these bound the coarse
# filter applied to enhanced_df below, which shrinks the dataset before the join

# entity_df defines the bounded time window
minmax_time = entity_df.agg(sql_min("joindate"), sql_max("joindate")).collect()
print(minmax_time)

In [None]:
min_time, max_time = minmax_time[0]["min(joindate)"], minmax_time[0]["max(joindate)"]
print(f'min_time: {min_time}')
print(f'max_time: {max_time}')
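Because `joindate` is declared as `StringType`, `sql_min` and `sql_max` compare strings, not instants. That yields the chronological answer only because the `'%Y-%m-%d %H:%M:%S'` format is fixed-width, zero-padded, and ordered from the most to the least significant field. A small plain-Python check of that property, using hypothetical values:

```python
from datetime import datetime

FMT = "%Y-%m-%d %H:%M:%S"

# Hypothetical joindate values in the notebook's fixed-width format
joindates = ["2021-01-31 09:00:00", "2021-01-31 23:00:00", "2021-01-31 00:00:00", "2021-01-31 13:00:00"]

# String min/max, as sql_min/sql_max compute them on a StringType column
lex_min, lex_max = min(joindates), max(joindates)

# Chronological min/max after parsing to datetime
parsed = [datetime.strptime(s, FMT) for s in joindates]
chrono_min, chrono_max = min(parsed).strftime(FMT), max(parsed).strftime(FMT)

print(lex_min == chrono_min and lex_max == chrono_max)  # True

# Without zero padding the orderings diverge: "9" sorts after "1"
unpadded = ["2021-01-31 9:00:00", "2021-01-31 13:00:00"]
print(max(unpadded))  # 2021-01-31 9:00:00, although 13:00 is later
```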

In [None]:
print("Before filter, count: " + str(enhanced_df.count()))

In [None]:
%%time

# Coarse filter: drop records after max_time and records more than the staleness window before min_time
# NOTE: This is a performance optimization; filtering here, before the per-{consumer_id, joindate} filtering, shrinks the join input

# Choose a "staleness" window, in days, before which we ignore records
allowed_staleness_days = 4

# Consumers (entities) with NO records inside the overall window
# {min_time - staleness window, max_time} drop out entirely

# The staleness check removes records whose event_time is MORE than 4 calendar days before min_time
# Usage: datediff(enddate, startdate) returns the number of days from startdate to enddate

filtered = enhanced_df.filter(
    (enhanced_df.event_time <= max_time) & 
    (datediff(lit(min_time), enhanced_df.event_time) <= allowed_staleness_days)
)
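Spark's `datediff(end, start)` counts days between the date parts and ignores the time of day. So "no more than 4 days before `min_time`" admits events from as early as midnight four calendar days earlier, which can be close to 5 days of elapsed time. The sketch below is a plain-Python model of the coarse filter, assuming `event_time` strings share the `joindate` format; if the Offline Store writes `event_time` in another format, the string comparison against `max_time` may not follow chronological order. `datediff_days` and `passes_coarse_filter` are hypothetical helpers.

```python
from datetime import datetime

FMT = "%Y-%m-%d %H:%M:%S"


def datediff_days(end: str, start: str) -> int:
    """Plain-Python model of Spark's datediff(end, start): days between the date parts only."""
    return (datetime.strptime(end, FMT).date() - datetime.strptime(start, FMT).date()).days


def passes_coarse_filter(event_time: str, min_time: str, max_time: str, staleness_days: int = 4) -> bool:
    """Model of the coarse filter above, assuming all timestamps share FMT."""
    return event_time <= max_time and datediff_days(min_time, event_time) <= staleness_days


example_min, example_max = "2021-01-31 00:00:00", "2021-01-31 23:00:00"
for t in ["2021-01-27 00:00:00", "2021-01-26 23:59:59", "2021-02-01 00:00:00"]:
    print(t, passes_coarse_filter(t, example_min, example_max))
```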

In [None]:
filtered.printSchema()
print("After filter, count: " + str(filtered.count()))

In [None]:
filtered.show(5)

In [None]:
filtered.select("cc_num", "consumer_id").show(5)

In [None]:
# Join filtered dataframe with generated entity dataframe; drop duplicate consumer_id field

joined = filtered.join(entity_df, filtered.consumer_id == entity_df.consumer_id, "inner").drop(entity_df.consumer_id)
print("Joined count: " + str(joined.count()))

In [None]:
joined.show(5)

In [None]:
# Per-entity filter: drop records after each entity's joindate (future data)
# and records more than the staleness window before it (stale data)
drop_future_and_stale = joined.filter(
    (joined.event_time <= joined.joindate)
    & (datediff(joined.joindate, joined.event_time) <= allowed_staleness_days)
)

print("After drop stale, count: " + str(drop_future_and_stale.count()))

In [None]:
drop_future_and_stale.show(5)

In [None]:
# Use reduceByKey to keep the most recent record per (consumer_id, joindate) entity;
# keying on consumer_id alone would merge rows for a consumer sampled with two different joindates
take_latest = (
    drop_future_and_stale.rdd.map(lambda x: ((x.consumer_id, x.joindate), x))
    .reduceByKey(
        lambda x, y: x if x.event_time >= y.event_time else y
    )  # We could have used api_invocation_time as a tie-breaker
    .values()  # drop keys
)
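`reduceByKey` may combine values in any order across partitions, so its function should be associative and commutative. Comparing `event_time` alone returns whichever record happens to arrive first when two records share an `event_time`; adding `api_invocation_time` as a tie-breaker, as the comment above suggests, makes the choice deterministic. The following is a plain-Python model, keyed here by `(consumer_id, joindate)` so that each entity keeps its own row; `reduce_by_key` is a hypothetical stand-in, not the PySpark API, and the rows are hypothetical.

```python
from functools import reduce
from typing import Callable, Dict, Hashable, Iterable, Tuple


def reduce_by_key(pairs: Iterable[Tuple[Hashable, dict]], func: Callable[[dict, dict], dict]) -> Dict[Hashable, dict]:
    """Plain-Python stand-in for RDD.reduceByKey: fold together values sharing a key."""
    groups: Dict[Hashable, list] = {}
    for key, value in pairs:
        groups.setdefault(key, []).append(value)
    return {key: reduce(func, values) for key, values in groups.items()}


def newer(x: dict, y: dict) -> dict:
    """Keep the later event_time; break ties on api_invocation_time."""
    x_rank = (x["event_time"], x["api_invocation_time"])
    y_rank = (y["event_time"], y["api_invocation_time"])
    return x if x_rank >= y_rank else y


# Hypothetical rows after the per-entity filter
rows = [
    {"consumer_id": "c1", "joindate": "2021-01-31 12:00:00", "event_time": "2021-01-30 08:00:00", "api_invocation_time": "2021-01-30 08:00:05", "v": 1},
    {"consumer_id": "c1", "joindate": "2021-01-31 12:00:00", "event_time": "2021-01-31 08:00:00", "api_invocation_time": "2021-01-31 08:00:01", "v": 2},
    {"consumer_id": "c1", "joindate": "2021-01-31 12:00:00", "event_time": "2021-01-31 08:00:00", "api_invocation_time": "2021-01-31 08:00:09", "v": 3},
    {"consumer_id": "c1", "joindate": "2021-01-30 12:00:00", "event_time": "2021-01-30 08:00:00", "api_invocation_time": "2021-01-30 08:00:05", "v": 1},
]
pairs = (((r["consumer_id"], r["joindate"]), r) for r in rows)
latest = reduce_by_key(pairs, newer)
for key, row in sorted(latest.items()):
    print(key, row["v"])
```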


In [None]:
# Convert to DataFrame
latest_df = take_latest.toDF(drop_future_and_stale.schema)

In [None]:
# Drop extra columns
columns_to_drop = ["write_time", "is_deleted", "year", "month", "day", "hour", "query_time", "api_invocation_time"]
final_df = latest_df.drop(*columns_to_drop)

print('Final count: ' + str(final_df.count()))

In [None]:
# Show final query results

final_df.show(10)

# To save the query result to S3 (PREFIX was never defined; use BASE_PREFIX and the s3a:// scheme):
# OUTPUT_PATH = f"s3a://{BUCKET}/{BASE_PREFIX}/test_query_output"
# final_df.write.parquet(OUTPUT_PATH, mode="overwrite")