### Use SageMaker Feature Store and Apache Spark to generate Point-In-Time queries

The following notebook uses SageMaker Feature Store and Apache Spark to build a set of Dataframes and queries that demonstrate a pattern for historical (point-in-time) feature lookups. We will show how to build "point-in-time" feature sets by starting with raw transactional data, joining that data with records from the Offline Feature Store, and then building an Entity Dataframe that defines the entities we care about and the reference timestamp for each. Techniques include building Spark Dataframes, using outer and inner table joins, using Dataframe filters to prune records outside our time window, and finally using Spark `reduceByKey` to reduce the dataset to the latest record per entity and reference timestamp.

Notes: This notebook relies on the outputs of the other two notebooks in this repo, which do the following:
- `1_generate_creditcard_transactions.ipynb` : generates raw transaction data for credit cards and consumers
- `2_create_feature_groups.ipynb` : creates two feature groups which are populated below

In [None]:
!pip install --upgrade sagemaker

In [None]:
import sagemaker as sm
sm.__version__

In [None]:
# Import pyspark and build Spark session

from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.functions import datediff
from pyspark.sql.functions import lit
from pyspark.sql.functions import col
from pyspark.sql.functions import max as sql_max
from pyspark.sql.functions import min as sql_min
from pyspark.sql.functions import monotonically_increasing_id

from pyspark.sql.types import StringType
from pyspark.sql.types import IntegerType
from pyspark.sql.types import FractionalType
from pyspark.sql.types import DoubleType
from pyspark.sql.types import FloatType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType

from pyspark import SparkContext, SparkConf
import sagemaker_pyspark
import datetime
import random

# Configure Spark to use the SageMaker Spark dependency jars
classpath = ":".join(sagemaker_pyspark.classpath_jars())

In [None]:
spark = (SparkSession
    .builder
    .config("spark.driver.extraClassPath", classpath)
    .config("spark.executor.memory", '1g')
    .config('spark.executor.cores', '16')
    .config("spark.driver.memory",'8g')
    .getOrCreate())

In [None]:
sc = spark.sparkContext
print(sc.version)

In [None]:
import boto3
import sagemaker
from sagemaker.session import Session

role = sagemaker.get_execution_role()
sagemaker_session = sagemaker.Session()
region = sagemaker_session.boto_region_name
boto_session = boto3.Session(region_name=region)
sagemaker_client = boto_session.client(service_name='sagemaker', region_name=region)

BUCKET = sagemaker_session.default_bucket()
print(BUCKET)


In [None]:
# Setup config variables, paths, names, etc.
import os

BASE_PREFIX = "sagemaker-featurestore-blog"

OFFLINE_STORE_BASE_URI = f's3://{BUCKET}/{BASE_PREFIX}'

RAW_PREFIX = os.path.join(BASE_PREFIX, 'raw')
AGG_PREFIX = os.path.join(BASE_PREFIX, 'aggregated')

RAW_FEATURES_PATH_S3 = f"s3://{BUCKET}/{RAW_PREFIX}/"
RAW_FEATURES_PATH_PARQUET = f"s3a://{BUCKET}/{RAW_PREFIX}/"
print(f'S3 Raw Transactions S3 path: {RAW_FEATURES_PATH_S3}')

AGG_FEATURES_PATH_S3 = f"s3://{BUCKET}/{AGG_PREFIX}/"
AGG_FEATURES_PATH_PARQUET = f"s3a://{BUCKET}/{AGG_PREFIX}/"
print(f'S3 Aggregated Data S3 Path: {AGG_FEATURES_PATH_S3}')

CONS_FEATURE_GROUP = "consumer-fg"
CARD_FEATURE_GROUP = "credit-card-fg"

In [None]:
# A Feature Store feature group requires the ISO-8601 string format yyyy-MM-dd'T'HH:mm:ssZ
# when its EventTime feature is of type String

ISO_8601_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
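Because `event_time` is stored as a string, the aggregation and filtering queries in this notebook compare it with plain string comparisons. That works because a fixed-width, zero-padded ISO-8601 format in a single timezone (`Z`, i.e. UTC) sorts lexicographically in the same order as chronologically. A small pure-Python check of that property, as a sketch; `to_event_time` is a hypothetical helper, not part of the notebook:

```python
import datetime

ISO_8601_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

def to_event_time(dt: datetime.datetime) -> str:
    """Format a UTC datetime as an EventTime string (hypothetical helper)."""
    return dt.strftime(ISO_8601_DATETIME_FORMAT)

times = [
    datetime.datetime(2021, 3, 9, 23, 59, 59),
    datetime.datetime(2021, 3, 10, 0, 0, 0),
    datetime.datetime(2021, 3, 1, 12, 30, 0),
]
strings = [to_event_time(t) for t in times]

# Sorting the strings gives the same order as sorting the datetimes,
# because every field is zero-padded and the timezone is always 'Z'.
print(sorted(strings) == [to_event_time(t) for t in sorted(times)])

# Round trip: the same format string parses the value back
parsed = datetime.datetime.strptime(strings[0], ISO_8601_DATETIME_FORMAT)
print(strings[0], parsed)
```

The property breaks if timestamps mix timezones or drop zero padding, so the single fixed format matters.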

## Generate and ingest aggregate features for the credit card and consumer feature groups
This section can be moved to another preparation notebook to be run before the point in time query.

#### Let's retrieve our credit card transaction data

In [None]:
from pyspark.sql.types import StructField, StructType, StringType, DoubleType, TimestampType, LongType

raw_schema = StructType([StructField('tid', StringType(), True),
                    StructField('event_time', StringType(), True),
                    StructField('cc_num', LongType(), True),
                    StructField('consumer_id', StringType(), True),
                    StructField('amount', DoubleType(), True),
                    StructField('fraud_label', StringType(), True)])

# Build path to transactions data file
raw_file = os.path.join(RAW_FEATURES_PATH_PARQUET, 'transactions.csv')
print(raw_file)

In [None]:
transactions_df = spark.read.csv(raw_file, header=True, schema=raw_schema)

In [None]:
transactions_df.printSchema()
transactions_df.count()

In [None]:
transactions_df.createOrReplaceTempView('trans')

In [None]:
from pyspark.sql.functions import lit

def agg(by_col, lookback_days, start_day, end_day):
    transactions_df.createOrReplaceTempView('trans')
    all_agg_rows = None
    for curr_day in range(start_day, end_day + 1):
        min_day = max(1, (curr_day - lookback_days +1))
        print(f'aggregating "{by_col}" for day {curr_day:02d}, look back to {min_day:02d} beginning of day...')
        start_time = f'2021-03-{min_day:02d}T00:00:00Z'
        end_time = f'2021-03-{curr_day:02d}T23:59:59Z'
        event_time = end_time

        sub_query = f'SELECT {by_col}, '
        sub_query += f'COUNT(*) as num_trans_last_{lookback_days}d, AVG(amount) as avg_amt_last_{lookback_days}d FROM trans'
        sub_query += f' where event_time >= "{start_time}" and event_time <= "{end_time}" GROUP BY {by_col}'

        d_query = f'select distinct({by_col}) from trans '

        total_query = f'select a.{by_col}, b.num_trans_last_{lookback_days}d, b.avg_amt_last_{lookback_days}d from ({d_query}) a left join ({sub_query}) b on a.{by_col} = b.{by_col}'
        print(f' Using query: {total_query}\n')
        total_df = spark.sql(total_query)

        # Stamp every row with the end of the current day, i.e. when this simulated "daily batch job" runs
        total_df = total_df.withColumn('event_time', lit(event_time))
#         print(f' {total_df.count()} rows')
        
        if all_agg_rows is None:
            all_agg_rows = spark.createDataFrame([], total_df.schema)
        all_agg_rows = all_agg_rows.union(total_df)
        del total_df
        
    return all_agg_rows
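For each simulated daily batch, `agg` counts and averages the transactions in a trailing window of `lookback_days` days that ends at the end of the current day (clamped to March 1). The left join from the distinct entity list keeps entities that had no transactions in the window, with null aggregates. A minimal pure-Python sketch of a single daily batch, assuming transactions are plain `(entity_id, event_time, amount)` tuples instead of a Spark DataFrame; `daily_window_agg` is a hypothetical name:

```python
from collections import defaultdict

def daily_window_agg(transactions, lookback_days, curr_day):
    """Aggregate one 'daily batch' for March 2021, mirroring one iteration of agg().

    transactions: a list of (entity_id, event_time_iso, amount) tuples.
    Returns {entity_id: (count, avg_amount, event_time)}. Entities with no
    transactions in the window get (None, None, event_time), like the left join.
    """
    min_day = max(1, curr_day - lookback_days + 1)
    start_time = f'2021-03-{min_day:02d}T00:00:00Z'
    end_time = f'2021-03-{curr_day:02d}T23:59:59Z'

    all_entities = {entity for entity, _, _ in transactions}
    amounts = defaultdict(list)
    for entity, event_time, amount in transactions:
        # ISO-8601 strings compare chronologically, so string comparison suffices
        if start_time <= event_time <= end_time:
            amounts[entity].append(amount)

    result = {}
    for entity in all_entities:
        window = amounts.get(entity)
        if window:
            result[entity] = (len(window), sum(window) / len(window), end_time)
        else:
            result[entity] = (None, None, end_time)
    return result

txns = [
    ('c1', '2021-03-01T10:00:00Z', 10.0),
    ('c1', '2021-03-03T09:00:00Z', 30.0),
    ('c2', '2021-03-01T11:00:00Z', 5.0),
]
print(daily_window_agg(txns, lookback_days=1, curr_day=3))
print(daily_window_agg(txns, lookback_days=7, curr_day=3))
```

With a 1-day lookback on day 3, `c2` has no transactions and gets null aggregates; with a 7-day lookback both entities are counted.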

In [None]:
cc_num_rows_7 = agg('cc_num', 7, 1, 31)
cc_num_rows_1 = agg('cc_num', 1, 1, 31)
consumer_rows_7 = agg('consumer_id', 7, 1, 31)
consumer_rows_1 = agg('consumer_id', 1, 1, 31)

In [None]:
%%time
# Combine the 7-day and 1-day aggregates into one row per entity and event time.
# Join on the natural key (entity id + event_time) rather than on a generated row id:
# monotonically_increasing_id() values depend on how rows are partitioned and ordered,
# so ids generated separately for two DataFrames are not guaranteed to line up.
cc_num_all = cc_num_rows_7.join(cc_num_rows_1, ['cc_num', 'event_time'], 'outer')
consumer_all = consumer_rows_7.join(consumer_rows_1, ['consumer_id', 'event_time'], 'outer')
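The intended result of combining the two aggregates is exactly one row per `(entity id, event_time)` carrying both the 7-day and the 1-day features. Expressed over plain dictionaries, that is a full outer join on the key. A pure-Python sketch for illustration; `merge_aggregates` is a hypothetical helper, not part of the notebook:

```python
def merge_aggregates(rows_7d, rows_1d):
    """Full outer join of two aggregate tables keyed on (entity_id, event_time).

    rows_7d / rows_1d: dicts mapping (entity_id, event_time) -> dict of features.
    A feature missing on one side is simply absent (null in Spark terms).
    """
    merged = {}
    for key in rows_7d.keys() | rows_1d.keys():
        merged[key] = {**rows_7d.get(key, {}), **rows_1d.get(key, {})}
    return merged

rows_7d = {('c1', '2021-03-03T23:59:59Z'): {'num_trans_last_7d': 2},
           ('c2', '2021-03-03T23:59:59Z'): {'num_trans_last_7d': 1}}
rows_1d = {('c1', '2021-03-03T23:59:59Z'): {'num_trans_last_1d': 1}}
print(merge_aggregates(rows_7d, rows_1d))
```

Keying on the entity and event time means the merge does not depend on row order or partitioning.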

In [None]:
spark.sparkContext.getConf().getAll()

In [None]:
%%time
cc_num_all.orderBy(cc_num_all.event_time.desc()).show(10)
cc_num_all.count()

In [None]:
%%time
consumer_all.orderBy(consumer_all.event_time.desc()).show(10)
consumer_all.count()

### Now, we will ingest data into the Feature Store

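We will use Spark to parallelize ingestion into the Feature Store, first for credit cards and then for consumers. Each Spark partition is written by its own `sagemaker-featurestore-runtime` client, with one `PutRecord` call per row.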

In [None]:
import boto3
from botocore.config import Config 

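def ingest_df_to_fg(feature_group_name, rows, columns):
    """Write every row of one Spark partition to a feature group with PutRecord.

    Runs inside foreachPartition, so it creates its own boto3 session and client.
    Only the features listed in `columns` are written; null values are skipped.
    """
    session = boto3.session.Session()
    runtime = session.client(service_name='sagemaker-featurestore-runtime',
                             config=Config(retries={'max_attempts': 10, 'mode': 'standard'}))
    for row in rows:
        row_dict = row.asDict()
        record = [{"FeatureName": column, "ValueAsString": str(row_dict[column])}
                  for column in columns if row_dict.get(column) is not None]
        resp = runtime.put_record(FeatureGroupName=feature_group_name, Record=record)
        if resp['ResponseMetadata']['HTTPStatusCode'] != 200:
            # raise needs an exception instance; raising a plain string is a TypeError in Python 3
            raise RuntimeError(f'PutRecord failed: {resp}')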
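Each record passed to `PutRecord` is a list of `{"FeatureName": ..., "ValueAsString": ...}` entries, and every value is sent as a string. Null features are left out rather than sent as the string `'None'`. The payload construction can be isolated as a pure function and checked without AWS access; `to_put_record_payload` is a hypothetical helper name, sketched under the assumption that a row is available as a dict:

```python
def to_put_record_payload(row: dict, columns: list) -> list:
    """Build the Record argument for a sagemaker-featurestore-runtime put_record call.

    Values are stringified (ValueAsString); None or missing values are skipped.
    """
    return [
        {"FeatureName": column, "ValueAsString": str(row[column])}
        for column in columns
        if row.get(column) is not None
    ]

row = {'cc_num': 4111111111111111, 'event_time': '2021-03-31T23:59:59Z',
       'num_trans_last_7d': 3, 'avg_amt_last_7d': 42.5, 'num_trans_last_1d': None}
columns = ['cc_num', 'event_time', 'num_trans_last_7d', 'num_trans_last_1d',
           'avg_amt_last_7d', 'avg_amt_last_1d']
print(to_put_record_payload(row, columns))
```

Because the `if` clause is evaluated before the value expression, a column missing from the row is skipped instead of raising `KeyError`.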

In [None]:
%%time

columns = ['cc_num','event_time','num_trans_last_7d','num_trans_last_1d','avg_amt_last_7d','avg_amt_last_1d']
cc_num_all.foreachPartition(lambda rows: ingest_df_to_fg(CARD_FEATURE_GROUP, rows, columns))

In [None]:
%%time

columns = ['consumer_id','event_time','num_trans_last_7d','num_trans_last_1d','avg_amt_last_7d','avg_amt_last_1d']
consumer_all.foreachPartition(lambda rows: ingest_df_to_fg(CONS_FEATURE_GROUP, rows, columns))

## Perform a point-in-time correct query
We begin by creating an Entity Dataframe that identifies the consumer_ids of interest, each paired with an event_time that represents the cutoff time for that entity. The consumer_id is used to join against the feature history read from the Offline Store, and the event_time (stored in the `query_date` column) is used in filter operations to identify the "point-in-time" correct data.
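The logic of the query can be sketched in pure Python on a handful of rows: join the feature history to the entity rows on `consumer_id`, drop history with an event time after the query time (future leakage) or more than `allowed_staleness_days` calendar days before it, and keep the latest remaining record per `(consumer_id, query_date)`, using `write_time` as a tie breaker. Rows are plain dicts here and `point_in_time_lookup` is a hypothetical name, not a Feature Store API:

```python
import datetime

ISO_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

def days_between(end_iso: str, start_iso: str) -> int:
    """Calendar-day difference, like Spark's datediff(end, start), ignoring time of day."""
    end = datetime.datetime.strptime(end_iso, ISO_FORMAT).date()
    start = datetime.datetime.strptime(start_iso, ISO_FORMAT).date()
    return (end - start).days

def point_in_time_lookup(history, entities, allowed_staleness_days=14):
    """Return {(consumer_id, query_date): latest eligible feature record}."""
    latest = {}
    for ent in entities:
        key = (ent['consumer_id'], ent['query_date'])
        for rec in history:
            if rec['consumer_id'] != ent['consumer_id']:
                continue  # inner join on consumer_id
            if rec['event_time'] > ent['query_date']:
                continue  # future record: would leak information
            if days_between(ent['query_date'], rec['event_time']) > allowed_staleness_days:
                continue  # older than the staleness window
            best = latest.get(key)
            # Latest event_time wins; write_time breaks ties (backfills, corrections)
            if best is None or (rec['event_time'], rec['write_time']) > (best['event_time'], best['write_time']):
                latest[key] = rec
    return latest

history = [
    {'consumer_id': 'c1', 'event_time': '2021-03-01T23:59:59Z', 'write_time': '2021-04-01T00:00:00Z', 'avg_amt_last_7d': 10.0},
    {'consumer_id': 'c1', 'event_time': '2021-03-10T23:59:59Z', 'write_time': '2021-04-01T00:00:00Z', 'avg_amt_last_7d': 20.0},
    {'consumer_id': 'c1', 'event_time': '2021-03-10T23:59:59Z', 'write_time': '2021-04-02T00:00:00Z', 'avg_amt_last_7d': 25.0},
    {'consumer_id': 'c1', 'event_time': '2021-03-20T23:59:59Z', 'write_time': '2021-04-01T00:00:00Z', 'avg_amt_last_7d': 40.0},
    {'consumer_id': 'c2', 'event_time': '2021-03-01T23:59:59Z', 'write_time': '2021-04-01T00:00:00Z', 'avg_amt_last_7d': 5.0},
]
entities = [
    {'consumer_id': 'c1', 'query_date': '2021-03-15T12:00:00Z'},
    {'consumer_id': 'c1', 'query_date': '2021-03-05T00:00:00Z'},
    {'consumer_id': 'c2', 'query_date': '2021-03-20T00:00:00Z'},
]
for key, rec in sorted(point_in_time_lookup(history, entities).items()):
    print(key, rec['avg_amt_last_7d'])
```

The March 15 query for `c1` picks the corrected March 10 record (later `write_time`), the March 5 query falls back to the March 1 record, and `c2` gets no row because its only record is 19 days older than the query.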

#### The Entity Dataframe will consist of real Consumer IDs and real event timestamps

First, we need to create an Entity Dataframe consisting of a list of our "target" Consumer IDs, plus a set of realistic timestamps (event_time) at which to run the point-in-time queries. Here we use every transaction from the last week of March 2021.

In [None]:
last_1w_df = spark.sql('select * from trans where event_time >= "2021-03-25T00:00:00Z" and event_time <= "2021-03-31T23:59:59Z"')

In [None]:
cid_ts_tuples = last_1w_df.rdd.map(lambda r: (r.consumer_id, r.cc_num, r.event_time, r.amount, int(r.fraud_label))).collect()

In [None]:
len(cid_ts_tuples)

In [None]:
# Create the actual Entity Dataframe
# (i.e. the dataframe that defines our set of Consumer IDs and timestamps for our point-in-time queries)
# cc_num is read as LongType in raw_schema, so it must be LongType here too:
# createDataFrame verifies values against the schema and rejects an int in a StringType field.

entity_df_schema = StructType([
    StructField('consumer_id', StringType(), False),
    StructField('cc_num', LongType(), False),
    StructField('query_date', StringType(), False),
    StructField('amount', FloatType(), False),
    StructField('fraud_label', IntegerType(), False)
])

In [None]:
# Create entity data frame

entity_df = spark.createDataFrame(cid_ts_tuples, entity_df_schema)

entity_df.show(10)

#### Use the SageMaker client to find the location of the Offline Store in S3
We will use the `describe_feature_group` method to look up the S3 URI of the Offline Store data files.

In [None]:
# Lookup S3 Location of Offline Store

feature_group_info = sagemaker_client.describe_feature_group(FeatureGroupName=CONS_FEATURE_GROUP)
resolved_offline_store_s3_location = feature_group_info['OfflineStoreConfig']['S3StorageConfig']['ResolvedOutputS3Uri']

# Spark's Parquet file reader requires replacement of 's3' with 's3a'
offline_store_s3a_uri = resolved_offline_store_s3_location.replace("s3:", "s3a:")

print(offline_store_s3a_uri)
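`str.replace("s3:", "s3a:")` rewrites every occurrence of `s3:` in the string, not only the scheme. The resolved Offline Store URI normally contains it only once, but a helper that touches only the leading scheme makes the intent explicit. A minimal sketch; `to_s3a_uri` is a hypothetical helper, and the bucket path below is made up:

```python
def to_s3a_uri(uri: str) -> str:
    """Rewrite an s3:// URI to the s3a:// scheme used by Hadoop's S3A connector."""
    if uri.startswith("s3a://"):
        return uri  # already in the form Spark's reader expects
    if uri.startswith("s3://"):
        return "s3a://" + uri[len("s3://"):]
    raise ValueError(f"not an S3 URI: {uri}")

print(to_s3a_uri("s3://my-bucket/offline-store/consumer-fg/data"))
```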

#### Read the offline store into a dataframe

In [None]:
%%time
# Read Offline Store data
feature_store_df = spark.read.parquet(offline_store_s3a_uri)

In [None]:
feature_store_df.printSchema()

In [None]:
%%time
feature_store_df.count()

#### Remove records marked as deleted (is_deleted attribute)

In [None]:
feature_store_active_df = feature_store_df.filter(~feature_store_df.is_deleted)

In [None]:
feature_store_active_df.select('consumer_id', 'avg_amt_last_7d', 'event_time', 'write_time', 'is_deleted').show(5)

In [None]:
row1 = feature_store_active_df.first()
test_consumer_id = row1['consumer_id']
print(test_consumer_id)

In [None]:
feature_store_active_df.select('consumer_id', 'avg_amt_last_7d', 'event_time', 'write_time', 'api_invocation_time')\
    .where(feature_store_active_df.consumer_id == test_consumer_id)\
    .orderBy('event_time','write_time')\
    .show(10,False)

#### Filter out history that is outside of our target time window

In [None]:
# NOTE: This filter is simply a performance optimization.
# Filter out records from after the query max_time and from before the staleness window preceding min_time;
# doing this before the per-row {consumer_id, query_date} filtering speeds up the subsequent filters.

# Choose a "staleness" window of time before which we want to ignore records
allowed_staleness_days = 14

# Eliminate history that is outside of our time window 
# this window represents the {max_time - min_time} delta, plus our staleness window

# entity_df used to define bounded time window
minmax_time = entity_df.agg(sql_min("query_date"), sql_max("query_date")).collect()
min_time, max_time = minmax_time[0]["min(query_date)"], minmax_time[0]["max(query_date)"]
print(f'min_time: {min_time}, max_time: {max_time}, staleness days: {allowed_staleness_days}')

# Via the staleness check, we are actually removing items when event_time is MORE than N days before min_time
# Usage: datediff ( enddate, startdate ) - returns days

filtered = feature_store_active_df.filter(
    (feature_store_active_df.event_time <= max_time) & 
    (datediff(lit(min_time), feature_store_active_df.event_time) <= allowed_staleness_days)
)
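The pre-filter keeps only history that could match at least one entity row: records at or before the latest query time, and no more than `allowed_staleness_days` calendar days (as `datediff` counts them) before the earliest query time. A pure-Python sketch of the same predicate; `coarse_window_filter` is a hypothetical name:

```python
import datetime

ISO_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

def coarse_window_filter(event_times, query_times, allowed_staleness_days=14):
    """Return the event_time strings that survive the global time-window pre-filter."""
    min_time, max_time = min(query_times), max(query_times)
    min_date = datetime.datetime.strptime(min_time, ISO_FORMAT).date()
    kept = []
    for event_time in event_times:
        event_date = datetime.datetime.strptime(event_time, ISO_FORMAT).date()
        # Same predicate as the Spark filter:
        #   event_time <= max_time AND datediff(min_time, event_time) <= staleness
        if event_time <= max_time and (min_date - event_date).days <= allowed_staleness_days:
            kept.append(event_time)
    return kept

print(coarse_window_filter(
    ['2021-03-10T23:59:59Z', '2021-03-11T23:59:59Z', '2021-03-31T23:59:59Z', '2021-03-31T19:59:59Z'],
    ['2021-03-25T08:00:00Z', '2021-03-31T20:00:00Z']))
```

Note that a daily aggregate stamped `T23:59:59Z` on the same day as the latest query is dropped, because it was produced after that query time.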

#### Perform the actual point-in-time correct history query

In [None]:
%%time
t_joined = (filtered.join(entity_df, filtered.consumer_id == entity_df.consumer_id, 'inner')
    .drop(entity_df.consumer_id))

# Filter out data from after query time to remove future data leakage.
# Also filter out data that is older than our allowed staleness window (days before each query time)

drop_future_and_stale_df = t_joined.filter(
    (t_joined.event_time <= entity_df.query_date)
    & (datediff(entity_df.query_date, t_joined.event_time) <= allowed_staleness_days))

drop_future_and_stale_df.select('consumer_id','query_date','avg_amt_last_7d','event_time','write_time')\
    .where(drop_future_and_stale_df.consumer_id == test_consumer_id)\
    .orderBy(col('query_date').desc(),col('event_time').desc(),col('write_time').desc())\
    .show(15,False)

# Group by record id and query timestamp, select only the latest remaining record by event time,
# using write time as a tie breaker to account for any more recent backfills or data corrections.

latest = drop_future_and_stale_df.rdd.map(lambda x: (f'{x.consumer_id}-{x.query_date}', x))\
            .reduceByKey(
                lambda x, y: x if (x.event_time, x.write_time) > (y.event_time, y.write_time) else y).values()
latest_df = latest.toDF(drop_future_and_stale_df.schema)
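`reduceByKey` may combine values in any order and across partitions, so the reducer must give the same answer regardless of how values are grouped. Taking the maximum of the `(event_time, write_time)` tuple has that property (if two records tie on both fields, either one may be kept). The same reduction over plain Python dicts, using `functools.reduce` per key as a local stand-in for Spark's distributed reduce; `keep_latest` and `reduce_by_key` are hypothetical names:

```python
from collections import defaultdict
from functools import reduce

def keep_latest(x, y):
    """Same rule as the notebook's lambda: later event_time wins, then later write_time."""
    return x if (x['event_time'], x['write_time']) > (y['event_time'], y['write_time']) else y

def reduce_by_key(pairs, func):
    """Local stand-in for RDD.reduceByKey: group (key, value) pairs and reduce each group."""
    groups = defaultdict(list)
    for key, value in pairs:
        groups[key].append(value)
    return {key: reduce(func, values) for key, values in groups.items()}

versions = [
    {'event_time': '2021-03-10T23:59:59Z', 'write_time': '2021-04-01T00:00:00Z', 'v': 1},
    {'event_time': '2021-03-10T23:59:59Z', 'write_time': '2021-04-02T00:00:00Z', 'v': 2},
    {'event_time': '2021-03-09T23:59:59Z', 'write_time': '2021-04-05T00:00:00Z', 'v': 3},
]
key = ('c1', '2021-03-15T00:00:00Z')
print(reduce_by_key([(key, v) for v in versions], keep_latest)[key]['v'])
```

The record with the latest event time wins even though an older event was written later; `write_time` only decides between versions of the same event time.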

In [None]:
latest_df.select('consumer_id', 'query_date', 'avg_amt_last_7d', 'event_time', 'write_time')\
    .where(latest_df.consumer_id == test_consumer_id)\
    .orderBy(col('query_date').desc(),col('event_time').desc(),col('write_time').desc())\
    .show(15,False)

In [None]:
latest_df.count()

In [None]:
cols_to_drop = ('api_invocation_time','write_time','is_deleted','cc_num','year','month','day','hour')
latest_df = latest_df.drop(*cols_to_drop)

In [None]:
latest_df.select('query_date','event_time','avg_amt_last_7d','num_trans_last_7d','consumer_id').sample(
    withReplacement=False, fraction=0.001).show()

In [None]:
latest_df.printSchema()

## Create a sample training dataset with point-in-time queries against two feature groups

### Reusable function for point-in-time correct queries against a single feature group

In [None]:
def get_historical_feature_values_one_fg(
    fg_name: str, entity_df: DataFrame, spark: SparkSession,
    allowed_staleness_days: int = 14,
    remove_extra_columns: bool = True) -> DataFrame:
    
    # Get metadata for source feature group
    sm_client = boto3.Session().client(service_name='sagemaker')
    feature_group_info = sm_client.describe_feature_group(FeatureGroupName=fg_name)

    # Get the names of this feature group's RecordId and EventTime features
    record_id_name = feature_group_info['RecordIdentifierFeatureName']
    event_time_name = feature_group_info['EventTimeFeatureName']
    
    # Get S3 Location of this feature group's offline store. 
    # Note Spark's Parquet file reader requires replacement of 's3' with 's3a'
    resolved_offline_store_s3_location = \
        feature_group_info['OfflineStoreConfig']['S3StorageConfig']['ResolvedOutputS3Uri']
    offline_store_s3a_uri = resolved_offline_store_s3_location.replace("s3:", "s3a:")

    # Read the offline store into a dataframe
    feature_store_df = spark.read.parquet(offline_store_s3a_uri)
    
    # Filter out deleted records, if any
    fs_active_df = feature_store_df.filter(~feature_store_df.is_deleted)
    
    # Determine min and max time of query timestamps
    minmax_time = entity_df.agg(sql_min("query_time"), sql_max("query_time")).collect()
    min_time, max_time = minmax_time[0]["min(query_time)"], minmax_time[0]["max(query_time)"]
    
    # Remove all rows that are outside of our time window, allowing for a buffer of staleness days
    filtered_df = fs_active_df.filter(
        (fs_active_df[event_time_name] <= max_time) & 
        (datediff(lit(min_time), fs_active_df[event_time_name]) <= allowed_staleness_days))
    
    # Join on record id between the input entity dataframe and the feature history dataframe
    joined_df = filtered_df.join(entity_df, 
                              filtered_df[record_id_name] == entity_df[record_id_name], 'inner')\
                                .drop(entity_df[record_id_name])

    # Filter out data from after query time to remove future data leakage
    # Also filter out data that is beyond our allowed staleness window (days before each query time)
    drop_future_and_stale_df = joined_df.filter(
        (joined_df[event_time_name] <= entity_df.query_time)
        & (datediff(entity_df.query_time, joined_df[event_time_name]) <= allowed_staleness_days))

    # Group by composite key (to uniquely identify the combination of an entity id and a query timestamp),
    # and keep only the very latest remaining feature vector coming closest to the input timestamp.
    # Use write time as a tie breaker to account for any more recent backfills or data corrections.
    latest = drop_future_and_stale_df.rdd.map(lambda x: (f'{x[record_id_name]}-{x.query_time}', x))\
                .reduceByKey(
                    lambda x, y: x if (x[event_time_name], x.write_time) > 
                                      (y[event_time_name], y.write_time) else y).values()
    latest_df = latest.toDF(drop_future_and_stale_df.schema)
    
    # Clean up excess columns
    if remove_extra_columns:
        cols_to_drop = ('api_invocation_time','write_time','is_deleted',
                        record_id_name,'query_time',event_time_name,
                        'year','month','day','hour')
        latest_df = latest_df.drop(*cols_to_drop)
    
    # Return results of point in time query
    return latest_df

### Use a handful of transactions from our transactions history as a base
Note that we select a pair of entity identifiers, `consumer_id` and `cc_num`, to drive the corresponding
queries against the feature value history for those entities. We also add a monotonically increasing temporary
identifier to the query dataset. This lets us accurately join the results of each entity-specific
point-in-time query into a combined training dataset containing multiple feature vectors for each transaction.

In [None]:
sample_count = 10
my_entity_df = transactions_df.select('consumer_id','cc_num','event_time','amount','fraud_label')\
                    .orderBy(col('event_time').desc()).limit(sample_count)
my_entity_cols = ['consumer_id','cc_num','query_time','amount','fraud_label']
my_entity_df = my_entity_df.toDF(*my_entity_cols)

my_entity_df = my_entity_df.withColumn("_tmp_id", monotonically_increasing_id())
my_entity_df.drop('_tmp_id').show(10)

### Do a point-in-time correct query to retrieve Consumer features for each transaction

In [None]:
cons_df = get_historical_feature_values_one_fg(CONS_FEATURE_GROUP, my_entity_df, spark)
cons_df.select('num_trans_last_7d', 'avg_amt_last_7d', 'num_trans_last_1d', 'avg_amt_last_1d', 
               'fraud_label').show(sample_count, False)

### Do a point-in-time correct query to retrieve Credit Card features for each transaction

In [None]:
card_df = get_historical_feature_values_one_fg(CARD_FEATURE_GROUP, my_entity_df, spark)
card_df.select('num_trans_last_7d', 'avg_amt_last_7d', 'num_trans_last_1d', 'avg_amt_last_1d', 
               'fraud_label').show(sample_count, False)

### Join the feature vectors from each feature group to form the final training dataset

In [None]:
# Drop IDs, as they aren't used for training
cols_to_drop = ('cc_num','consumer_id')

# Both feature groups use the same feature names, so add a prefix to keep them distinct
# in the training dataset. Renaming by name (rather than positionally with toDF)
# does not depend on the column order returned by the point-in-time query.
feature_names = ('num_trans_last_7d', 'avg_amt_last_7d', 'num_trans_last_1d', 'avg_amt_last_1d')

def prefix_features(df, prefix):
    df = df.drop(*cols_to_drop)
    for name in feature_names:
        df = df.withColumnRenamed(name, f'{prefix}_{name}')
    return df

card_df = prefix_features(card_df, 'card')
card_df.show(sample_count)

cons_df = prefix_features(cons_df, 'cons')
cons_df.show(sample_count)

In [None]:
# Join the feature vectors from each entity into a single training dataset
training_df = card_df.drop('fraud_label').drop('amount').join(cons_df, "_tmp_id", "outer").drop("_tmp_id").fillna(0)
training_df.show(sample_count)
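The final step is an outer join of the two per-entity results on `_tmp_id`, with prefixes keeping the identically named features apart and `fillna(0)` replacing nulls, for example where no fresh-enough feature record existed for a transaction. A pure-Python sketch over dicts keyed by `_tmp_id`; `combine_feature_vectors` is a hypothetical name, and unlike Spark's `fillna(0)` (which fills every numeric column) this sketch fills only the prefixed feature columns:

```python
def combine_feature_vectors(card_rows, cons_rows, feature_names):
    """Outer-join two {tmp_id: row_dict} maps, prefix feature names, fill missing features with 0."""
    combined = []
    for tmp_id in sorted(card_rows.keys() | cons_rows.keys()):
        card = card_rows.get(tmp_id, {})
        cons = cons_rows.get(tmp_id, {})
        row = {}
        for name in feature_names:
            row[f'card_{name}'] = card.get(name)
            row[f'cons_{name}'] = cons.get(name)
        # amount and fraud_label are taken from the consumer-side result, as in the notebook
        row['amount'] = cons.get('amount')
        row['fraud_label'] = cons.get('fraud_label')
        # Fill missing feature values with 0 (prefixed feature columns only)
        for key, value in row.items():
            if value is None and key.startswith(('card_', 'cons_')):
                row[key] = 0
        combined.append(row)
    return combined

card_rows = {0: {'num_trans_last_7d': 2, 'avg_amt_last_7d': 15.0},
             1: {'num_trans_last_7d': 1, 'avg_amt_last_7d': 9.0}}
cons_rows = {0: {'num_trans_last_7d': 4, 'avg_amt_last_7d': 20.0, 'amount': 12.0, 'fraud_label': '0'}}
for r in combine_feature_vectors(card_rows, cons_rows, ['num_trans_last_7d', 'avg_amt_last_7d']):
    print(r)
```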