### Use SageMaker Feature Store and Apache Spark to generate point-in-time queries
The following notebook uses SageMaker Feature Store and Apache Spark to build out a set of Dataframes and queries that provide a pattern for historical lookup capabilities. We will demonstrate how to build a "point-in-time" feature sets by starting with raw transactional data, joining that data with records from the Offline Store, and then building an "entity" dataset to define the items we care about and the timestamp of reference. Techniques include building Spark Dataframes, using outer and inner table joins, using Dataframe filters to prune items outside our timeframe, and finally using Spark `reduceByKey` to reduce the final the dataset. 

In [None]:
!pip install --upgrade sagemaker

In [1]:
import sagemaker as sm
sm.__version__

'2.42.0'

In [2]:
# Import pyspark and build Spark session

from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.functions import datediff
from pyspark.sql.functions import lit
from pyspark.sql.functions import col
from pyspark.sql.functions import max as sql_max
from pyspark.sql.functions import min as sql_min
from pyspark.sql.functions import monotonically_increasing_id

from pyspark.sql.types import StringType
from pyspark.sql.types import IntegerType
from pyspark.sql.types import FractionalType
from pyspark.sql.types import DoubleType
from pyspark.sql.types import FloatType
from pyspark.sql.types import StructField
from pyspark.sql.types import StructType

from pyspark import SparkContext, SparkConf
import sagemaker_pyspark
import datetime
import random

# Configure Spark to use the SageMaker Spark dependency jars
classpath = ":".join(sagemaker_pyspark.classpath_jars())

In [3]:
spark = (SparkSession
    .builder
    .config("spark.driver.extraClassPath", classpath)
    .config("spark.executor.memory", '1g')
    .config('spark.executor.cores', '16')
    .config("spark.driver.memory",'8g')
    .getOrCreate())

In [4]:
sc = spark.sparkContext
print(sc.version)

2.4.0


In [5]:
import boto3
import sagemaker
from sagemaker.session import Session

role = sagemaker.get_execution_role()
sagemaker_session = sagemaker.Session()
region = sagemaker_session.boto_region_name
boto_session = boto3.Session(region_name=region)

sagemaker_client = boto_session.client(service_name='sagemaker', region_name=region)

# If you need an instance of FeatureStore runtime:
#featurestore_runtime = boto_session.client(service_name='sagemaker-featurestore-runtime', region_name=region)

#feature_store_session = Session(
#    boto_session=boto_session,
#    sagemaker_client=sagemaker_client,
#    sagemaker_featurestore_runtime_client=featurestore_runtime
#)

BUCKET = sagemaker_session.default_bucket()
print(BUCKET)


sagemaker-us-east-1-572539092864


In [6]:
# Setup config variables, paths, names, etc.
import os

BASE_PREFIX = "sagemaker-featurestore-blog"

OFFLINE_STORE_BASE_URI = f's3://{BUCKET}/{BASE_PREFIX}'

RAW_PREFIX = os.path.join(BASE_PREFIX, 'raw')
AGG_PREFIX = os.path.join(BASE_PREFIX, 'aggregated')

RAW_FEATURES_PATH_S3 = f"s3://{BUCKET}/{RAW_PREFIX}/"
RAW_FEATURES_PATH_PARQUET = f"s3a://{BUCKET}/{RAW_PREFIX}/"
print(f'S3 Raw Transactions S3 path: {RAW_FEATURES_PATH_S3}')

AGG_FEATURES_PATH_S3 = f"s3://{BUCKET}/{AGG_PREFIX}/"
AGG_FEATURES_PATH_PARQUET = f"s3a://{BUCKET}/{AGG_PREFIX}/"
print(f'S3 Aggregated Data S3 Path: {AGG_FEATURES_PATH_S3}')

CONS_FEATURE_GROUP = "consumer-fg"
CARD_FEATURE_GROUP = "credit-card-fg"

S3 Raw Transactions S3 path: s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-blog/raw/
S3 Aggregated Data S3 Path: s3://sagemaker-us-east-1-572539092864/sagemaker-featurestore-blog/aggregated/


In [7]:
# Feature Store Group requires ISO-8601 string format: yyyy-MM-dd'T'HH:mm:ssZ
# when the EventTime required attribute is type String

ISO_8601_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

## Generate and ingest agg features for a credit card fg and a consumer fg
This section can be moved to another preparation notebook to be run before the point in time query.

#### Let's retreive our credit card transaction data

In [8]:
from pyspark.sql.types import StructField, StructType, StringType, DoubleType, TimestampType, LongType

raw_schema = StructType([StructField('tid', StringType(), True),
                    StructField('event_time', StringType(), True),
                    StructField('cc_num', LongType(), True),
                    StructField('consumer_id', StringType(), True),
                    StructField('amount', DoubleType(), True),
                    StructField('fraud_label', StringType(), True)])

# Build path to transactions data file
raw_file = os.path.join(RAW_FEATURES_PATH_PARQUET, 'transactions.csv')
print(raw_file)

s3a://sagemaker-us-east-1-572539092864/sagemaker-featurestore-blog/raw/transactions.csv


In [9]:
transactions_df = spark.read.csv(raw_file, header=True, schema=raw_schema)

In [10]:
transactions_df.printSchema()
transactions_df.count()

root
 |-- tid: string (nullable = true)
 |-- event_time: string (nullable = true)
 |-- cc_num: long (nullable = true)
 |-- consumer_id: string (nullable = true)
 |-- amount: double (nullable = true)
 |-- fraud_label: string (nullable = true)



100000

In [11]:
transactions_df.createOrReplaceTempView('trans')

In [12]:
from pyspark.sql.functions import lit

def agg(by_col, lookback_days, start_day, end_day):
    transactions_df.createOrReplaceTempView('trans')
    all_agg_rows = None
    for curr_day in range(start_day, end_day + 1):
        min_day = max(1, (curr_day - lookback_days +1))
        print(f'aggregating "{by_col}" for day {curr_day:02d}, look back to {min_day:02d} beginning of day...')
        start_time = f'2021-03-{min_day:02d}T00:00:00Z'
        end_time = f'2021-03-{curr_day:02d}T23:59:59Z'
        event_time = end_time

        sub_query = f'SELECT {by_col}, '
        sub_query += f'COUNT(*) as num_trans_last_{lookback_days}d, AVG(amount) as avg_amt_last_{lookback_days}d FROM trans'
        sub_query += f' where event_time >= "{start_time}" and event_time <= "{end_time}" GROUP BY {by_col}'

        d_query = f'select distinct({by_col}) from trans '

        total_query = f'select a.{by_col}, b.num_trans_last_{lookback_days}d, b.avg_amt_last_{lookback_days}d from ({d_query}) a left join ({sub_query}) b on a.{by_col} = b.{by_col}'
        print(f' Using query: {total_query}\n')
        total_df = spark.sql(total_query)

        # add a column to flag all of these records with an event time of the running of this "daily batch job"
        total_df = total_df.withColumn('event_time', lit(event_time))
#         print(f' {total_df.count()} rows')
        
        if all_agg_rows is None:
            all_agg_rows = spark.createDataFrame([], total_df.schema)
        all_agg_rows = all_agg_rows.union(total_df)
        del total_df
        
    return all_agg_rows

In [13]:
cc_num_rows_7 = agg('cc_num', 7, 1, 31)
cc_num_rows_1 = agg('cc_num', 1, 1, 31)
consumer_rows_7 = agg('consumer_id', 7, 1, 31)
consumer_rows_1 = agg('consumer_id', 1, 1, 31)

aggregating "cc_num" for day 01, look back to 01 beginning of day...
 Using query: select a.cc_num, b.num_trans_last_7d, b.avg_amt_last_7d from (select distinct(cc_num) from trans ) a left join (SELECT cc_num, COUNT(*) as num_trans_last_7d, AVG(amount) as avg_amt_last_7d FROM trans where event_time >= "2021-03-01T00:00:00Z" and event_time <= "2021-03-01T23:59:59Z" GROUP BY cc_num) b on a.cc_num = b.cc_num

aggregating "cc_num" for day 02, look back to 01 beginning of day...
 Using query: select a.cc_num, b.num_trans_last_7d, b.avg_amt_last_7d from (select distinct(cc_num) from trans ) a left join (SELECT cc_num, COUNT(*) as num_trans_last_7d, AVG(amount) as avg_amt_last_7d FROM trans where event_time >= "2021-03-01T00:00:00Z" and event_time <= "2021-03-02T23:59:59Z" GROUP BY cc_num) b on a.cc_num = b.cc_num

aggregating "cc_num" for day 03, look back to 01 beginning of day...
 Using query: select a.cc_num, b.num_trans_last_7d, b.avg_amt_last_7d from (select distinct(cc_num) from trans 

aggregating "cc_num" for day 23, look back to 17 beginning of day...
 Using query: select a.cc_num, b.num_trans_last_7d, b.avg_amt_last_7d from (select distinct(cc_num) from trans ) a left join (SELECT cc_num, COUNT(*) as num_trans_last_7d, AVG(amount) as avg_amt_last_7d FROM trans where event_time >= "2021-03-17T00:00:00Z" and event_time <= "2021-03-23T23:59:59Z" GROUP BY cc_num) b on a.cc_num = b.cc_num

aggregating "cc_num" for day 24, look back to 18 beginning of day...
 Using query: select a.cc_num, b.num_trans_last_7d, b.avg_amt_last_7d from (select distinct(cc_num) from trans ) a left join (SELECT cc_num, COUNT(*) as num_trans_last_7d, AVG(amount) as avg_amt_last_7d FROM trans where event_time >= "2021-03-18T00:00:00Z" and event_time <= "2021-03-24T23:59:59Z" GROUP BY cc_num) b on a.cc_num = b.cc_num

aggregating "cc_num" for day 25, look back to 19 beginning of day...
 Using query: select a.cc_num, b.num_trans_last_7d, b.avg_amt_last_7d from (select distinct(cc_num) from trans 

aggregating "cc_num" for day 14, look back to 14 beginning of day...
 Using query: select a.cc_num, b.num_trans_last_1d, b.avg_amt_last_1d from (select distinct(cc_num) from trans ) a left join (SELECT cc_num, COUNT(*) as num_trans_last_1d, AVG(amount) as avg_amt_last_1d FROM trans where event_time >= "2021-03-14T00:00:00Z" and event_time <= "2021-03-14T23:59:59Z" GROUP BY cc_num) b on a.cc_num = b.cc_num

aggregating "cc_num" for day 15, look back to 15 beginning of day...
 Using query: select a.cc_num, b.num_trans_last_1d, b.avg_amt_last_1d from (select distinct(cc_num) from trans ) a left join (SELECT cc_num, COUNT(*) as num_trans_last_1d, AVG(amount) as avg_amt_last_1d FROM trans where event_time >= "2021-03-15T00:00:00Z" and event_time <= "2021-03-15T23:59:59Z" GROUP BY cc_num) b on a.cc_num = b.cc_num

aggregating "cc_num" for day 16, look back to 16 beginning of day...
 Using query: select a.cc_num, b.num_trans_last_1d, b.avg_amt_last_1d from (select distinct(cc_num) from trans 

aggregating "consumer_id" for day 15, look back to 09 beginning of day...
 Using query: select a.consumer_id, b.num_trans_last_7d, b.avg_amt_last_7d from (select distinct(consumer_id) from trans ) a left join (SELECT consumer_id, COUNT(*) as num_trans_last_7d, AVG(amount) as avg_amt_last_7d FROM trans where event_time >= "2021-03-09T00:00:00Z" and event_time <= "2021-03-15T23:59:59Z" GROUP BY consumer_id) b on a.consumer_id = b.consumer_id

aggregating "consumer_id" for day 16, look back to 10 beginning of day...
 Using query: select a.consumer_id, b.num_trans_last_7d, b.avg_amt_last_7d from (select distinct(consumer_id) from trans ) a left join (SELECT consumer_id, COUNT(*) as num_trans_last_7d, AVG(amount) as avg_amt_last_7d FROM trans where event_time >= "2021-03-10T00:00:00Z" and event_time <= "2021-03-16T23:59:59Z" GROUP BY consumer_id) b on a.consumer_id = b.consumer_id

aggregating "consumer_id" for day 17, look back to 11 beginning of day...
 Using query: select a.consumer_id, 

aggregating "consumer_id" for day 16, look back to 16 beginning of day...
 Using query: select a.consumer_id, b.num_trans_last_1d, b.avg_amt_last_1d from (select distinct(consumer_id) from trans ) a left join (SELECT consumer_id, COUNT(*) as num_trans_last_1d, AVG(amount) as avg_amt_last_1d FROM trans where event_time >= "2021-03-16T00:00:00Z" and event_time <= "2021-03-16T23:59:59Z" GROUP BY consumer_id) b on a.consumer_id = b.consumer_id

aggregating "consumer_id" for day 17, look back to 17 beginning of day...
 Using query: select a.consumer_id, b.num_trans_last_1d, b.avg_amt_last_1d from (select distinct(consumer_id) from trans ) a left join (SELECT consumer_id, COUNT(*) as num_trans_last_1d, AVG(amount) as avg_amt_last_1d FROM trans where event_time >= "2021-03-17T00:00:00Z" and event_time <= "2021-03-17T23:59:59Z" GROUP BY consumer_id) b on a.consumer_id = b.consumer_id

aggregating "consumer_id" for day 18, look back to 18 beginning of day...
 Using query: select a.consumer_id, 

In [14]:
%%time
from pyspark.sql.functions import monotonically_increasing_id

# Add a temporary id column to each df, so that we can then join them column-wise

cc_num_rows_7 = cc_num_rows_7.withColumn("_tmp_id", monotonically_increasing_id())
cc_num_rows_1 = cc_num_rows_1.withColumn("_tmp_id", monotonically_increasing_id())

consumer_rows_7 = consumer_rows_7.withColumn("_tmp_id", monotonically_increasing_id())
consumer_rows_1 = consumer_rows_1.withColumn("_tmp_id", monotonically_increasing_id())

CPU times: user 1.29 ms, sys: 642 µs, total: 1.93 ms
Wall time: 11.5 ms


In [15]:
%%time
cc_num_all = cc_num_rows_7.join(cc_num_rows_1.drop('cc_num').drop('event_time'), "_tmp_id", "outer").drop('_tmp_id')
consumer_all = consumer_rows_7.join(consumer_rows_1.drop('consumer_id').drop('event_time'), "_tmp_id", "outer").drop('_tmp_id')

CPU times: user 1.54 ms, sys: 761 µs, total: 2.3 ms
Wall time: 28.2 ms


In [16]:
spark.sparkContext.getConf().getAll()

[('spark.app.id', 'local-1622040338807'),
 ('spark.driver.extraClassPath',
  '/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/sagemaker_pyspark/jars/aws-java-sdk-sagemakerruntime-1.11.835.jar:/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/sagemaker_pyspark/jars/hadoop-auth-2.8.1.jar:/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/sagemaker_pyspark/jars/hadoop-aws-2.8.1.jar:/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/sagemaker_pyspark/jars/sagemaker-spark_2.11-spark_2.4.0-1.4.2.dev0.jar:/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/sagemaker_pyspark/jars/aws-java-sdk-sts-1.11.835.jar:/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/sagemaker_pyspark/jars/hadoop-annotations-2.8.1.jar:/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/sagemaker_pyspark/jars/hadoop-common-2.8.1.jar:/home/ec2-user/anaconda3/envs/python3/lib/python3.6/site-packages/sagemaker_pyspar

In [17]:
%%time
cc_num_all.orderBy(cc_num_all.event_time.desc()).show(10)
cc_num_all.count()

+----------------+-----------------+------------------+--------------------+-----------------+------------------+
|          cc_num|num_trans_last_7d|   avg_amt_last_7d|          event_time|num_trans_last_1d|   avg_amt_last_1d|
+----------------+-----------------+------------------+--------------------+-----------------+------------------+
|4457570037098706|                5|           167.798|2021-03-31T23:59:59Z|             null|              null|
|4733882799298846|               10|203.76700000000002|2021-03-31T23:59:59Z|                2|38.330000000000005|
|4491810891527863|               11| 785.8854545454545|2021-03-31T23:59:59Z|                3|221.84666666666666|
|4775078220934872|               12|1014.2400000000002|2021-03-31T23:59:59Z|                6|1604.6950000000004|
|4693114584822046|               15| 1366.900666666667|2021-03-31T23:59:59Z|             null|              null|
|4966380846906071|               11| 184.2709090909091|2021-03-31T23:59:59Z|            

62000

In [18]:
%%time
consumer_all.orderBy(consumer_all.event_time.desc()).show(10)
consumer_all.count()

+------------------+-----------------+------------------+--------------------+-----------------+-----------------+
|       consumer_id|num_trans_last_7d|   avg_amt_last_7d|          event_time|num_trans_last_1d|  avg_amt_last_1d|
+------------------+-----------------+------------------+--------------------+-----------------+-----------------+
|GGVX95242869782633|               20| 948.3839999999998|2021-03-31T23:59:59Z|                3|2383.983333333333|
|QRTQ16845213506987|               19| 209.7015789473684|2021-03-31T23:59:59Z|                3|61.22666666666667|
|HPQZ79432746178012|               10|          1083.131|2021-03-31T23:59:59Z|                1|           264.36|
|YKLN71031286579847|               20|           1103.91|2021-03-31T23:59:59Z|                3|            44.01|
|OHGZ44766020135918|               31|490.17161290322593|2021-03-31T23:59:59Z|                5|          223.952|
|TDMV74965427985907|               40| 354.3527499999999|2021-03-31T23:59:59Z|  

31000

### Now, we will ingest data into the Feature Store

We will use Spark to parallelize the ingest of data into the Feature Store, first for consumers and second for credit cards.

In [19]:
import boto3
from botocore.config import Config 

def ingest_df_to_fg(feature_group_name, rows, columns):
    rows = list(rows)
    session = boto3.session.Session()
    runtime = session.client(service_name='sagemaker-featurestore-runtime',
                    config=Config(retries = {'max_attempts': 10, 'mode': 'standard'}))
    for index, row in enumerate(rows):
        record = [{"FeatureName": column, "ValueAsString": str(row[column])} \
                   for column in row.__fields__ if row[column] != None]
        resp = runtime.put_record(FeatureGroupName=feature_group_name, Record=record)
        if not resp['ResponseMetadata']['HTTPStatusCode'] == 200:
            raise (f'PutRecord failed: {resp}')
    return

In [20]:
%%time

columns = ['cc_num','event_time','num_trans_last_7d','num_trans_last_1d','avg_amt_last_7d','avg_amt_last_1d']
cc_num_all.foreachPartition(lambda rows: ingest_df_to_fg(CARD_FEATURE_GROUP, rows, columns))

CPU times: user 491 ms, sys: 245 ms, total: 736 ms
Wall time: 1min 25s


In [21]:
%%time

columns = ['consumer_id','event_time','num_trans_last_7d','num_trans_last_1d','avg_amt_last_7d','avg_amt_last_1d']
consumer_all.foreachPartition(lambda rows: ingest_df_to_fg(CONS_FEATURE_GROUP, rows, columns))

CPU times: user 501 ms, sys: 286 ms, total: 787 ms
Wall time: 56 s


## Perform point-in-time correct query
We begin by creating an Entity Dataframe which identifies the consumer_ids of interest, coupled with an event_time which represents our cutoff time for that entity. 

#### The Entity Dataframe will consist of real Consumer IDs and real event timestamps

First, we need to create an Entity Dataframe consisting of a list or our "target" Consumer IDs, plus a set of realistic timestamps (event_time) to run the point-in-time queries. 

In [22]:
last_1w_df = spark.sql('select * from trans where event_time >= "2021-03-25T00:00:00Z" and event_time <= "2021-03-31T23:59:59Z"')

In [23]:
cid_ts_tuples = last_1w_df.rdd.map(lambda r: (r.consumer_id, r.cc_num, r.event_time, r.amount, int(r.fraud_label))).collect()

In [24]:
len(cid_ts_tuples)

22547

In [25]:
# Create the actual Entity Dataframe
# (e.g. the dataframe that defines our set of Consumer IDs and timestamps for our point-in-time queries)

entity_df_schema = StructType([
    StructField('consumer_id', StringType(), False),
    StructField('cc_num', StringType(), False),
    StructField('query_date', StringType(), False),
    StructField('amount', FloatType(), False),
    StructField('fraud_label', IntegerType(), False)
])

In [26]:
# Create entity data frame

entity_df = spark.createDataFrame(cid_ts_tuples, entity_df_schema)
entity_df = entity_df.withColumn("_tmp_id", monotonically_increasing_id())

entity_df.show(10)

+------------------+----------------+--------------------+------+-----------+-------+
|       consumer_id|          cc_num|          query_date|amount|fraud_label|_tmp_id|
+------------------+----------------+--------------------+------+-----------+-------+
|QFFO43633815728040|4538707466326165|2021-03-25T00:00:01Z|  1.15|          0|      0|
|UUEG06702648357115|4631117361256523|2021-03-25T00:00:03Z|668.13|          0|      1|
|TJOJ29599331508229|4717798993183248|2021-03-25T00:00:05Z| 60.88|          0|      2|
|MFZR34661365402183|4865185401569996|2021-03-25T00:00:10Z|453.83|          0|      3|
|WERB44601853246125|4375373891114563|2021-03-25T00:00:36Z|896.26|          0|      4|
|IELN97823885632568|4597025692084773|2021-03-25T00:00:39Z| 58.31|          0|      5|
|FYTX19736334250684|4988283988187142|2021-03-25T00:02:42Z|5724.3|          0|      6|
|MFUQ42148144236605|4387073979354976|2021-03-25T00:03:59Z|  57.9|          0|      7|
|LHTC75807017825780|4238424338006123|2021-03-25T00:04:

#### Use Sagemaker Client to find the location of the offline store in S3
We will use the `describe_feature_group` method to lookup the S3 Uri location of the Offline Store data files.

In [27]:
# Lookup S3 Location of Offline Store

feature_group_info = sagemaker_client.describe_feature_group(FeatureGroupName=CONS_FEATURE_GROUP)
resolved_offline_store_s3_location = feature_group_info['OfflineStoreConfig']['S3StorageConfig']['ResolvedOutputS3Uri']

# Spark's Parquet file reader requires replacement of 's3' with 's3a'
offline_store_s3a_uri = resolved_offline_store_s3_location.replace("s3:", "s3a:")

print(offline_store_s3a_uri)

s3a://sagemaker-us-east-1-572539092864/sagemaker-featurestore-blog/572539092864/sagemaker/us-east-1/offline-store/consumer-fg-1622040537/data


#### Read the offline store into a dataframe

In [28]:
%%time
# Read Offline Store data
feature_store_df = spark.read.parquet(offline_store_s3a_uri)

CPU times: user 1.35 ms, sys: 725 µs, total: 2.07 ms
Wall time: 5.9 s


In [29]:
feature_store_df.printSchema()

root
 |-- consumer_id: string (nullable = true)
 |-- num_trans_last_7d: long (nullable = true)
 |-- avg_amt_last_7d: double (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- write_time: timestamp (nullable = true)
 |-- api_invocation_time: timestamp (nullable = true)
 |-- is_deleted: boolean (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- hour: integer (nullable = true)



In [30]:
%%time
feature_store_df.count()

CPU times: user 844 µs, sys: 454 µs, total: 1.3 ms
Wall time: 4.06 s


27497

#### Remove records marked as deleted (is_deleted attribute)

In [31]:
feature_store_active_df = feature_store_df.filter(~feature_store_df.is_deleted)

In [32]:
feature_store_active_df.select('consumer_id', 'avg_amt_last_7d', 'event_time', 'write_time', 'is_deleted').show(5)

+------------------+------------------+--------------------+--------------------+----------+
|       consumer_id|   avg_amt_last_7d|          event_time|          write_time|is_deleted|
+------------------+------------------+--------------------+--------------------+----------+
|MOXJ24964715530559| 75.09666666666666|2021-03-02T23:59:59Z|2021-05-26 15:06:...|     false|
|ACQG48444531254282| 5513.706666666666|2021-03-02T23:59:59Z|2021-05-26 15:06:...|     false|
|OOZF04050723251487| 141.0366666666667|2021-03-02T23:59:59Z|2021-05-26 15:06:...|     false|
|ALVF79373394142214|             350.0|2021-03-02T23:59:59Z|2021-05-26 15:06:...|     false|
|ZEQV63475057933816|2596.5850000000005|2021-03-02T23:59:59Z|2021-05-26 15:06:...|     false|
+------------------+------------------+--------------------+--------------------+----------+
only showing top 5 rows



In [33]:
row1 = feature_store_active_df.first()
test_consumer_id = row1['consumer_id']
print(test_consumer_id)

MOXJ24964715530559


In [35]:
feature_store_active_df.select('consumer_id', 'avg_amt_last_7d', 'event_time', 'write_time', 'api_invocation_time')\
    .where(feature_store_active_df.consumer_id == test_consumer_id)\
    .orderBy('event_time','write_time')\
    .show(10,False)

+------------------+------------------+--------------------+-----------------------+-------------------+
|consumer_id       |avg_amt_last_7d   |event_time          |write_time             |api_invocation_time|
+------------------+------------------+--------------------+-----------------------+-------------------+
|MOXJ24964715530559|85.655            |2021-03-01T23:59:59Z|2021-05-26 15:06:50.869|2021-05-26 15:06:25|
|MOXJ24964715530559|75.09666666666666 |2021-03-02T23:59:59Z|2021-05-26 15:06:50.843|2021-05-26 15:06:25|
|MOXJ24964715530559|125.03400000000002|2021-03-03T23:59:59Z|2021-05-26 15:06:51.017|2021-05-26 15:06:37|
|MOXJ24964715530559|226.69500000000002|2021-03-04T23:59:59Z|2021-05-26 15:06:51.017|2021-05-26 15:06:25|
|MOXJ24964715530559|184.52625         |2021-03-05T23:59:59Z|2021-05-26 15:06:51.015|2021-05-26 15:06:29|
|MOXJ24964715530559|414.56666666666666|2021-03-06T23:59:59Z|2021-05-26 15:07:17.895|2021-05-26 15:06:41|
|MOXJ24964715530559|579.6627272727272 |2021-03-07T23:59

#### Filter out history that is outside of our target time window

In [36]:
# NOTE: This filter is simply a performance optimization
# Filter out records from after query max_time and before staleness window prior to the min_time
# doing this prior to individual {consumer_id, joindate} filtering will speed up subsequent filters

# Choose a "staleness" window of time before which we want to ignore records
allowed_staleness_days = 14

# Eliminate history that is outside of our time window 
# this window represents the {max_time - min_time} delta, plus our staleness window

# entity_df used to define bounded time window
minmax_time = entity_df.agg(sql_min("query_date"), sql_max("query_date")).collect()
min_time, max_time = minmax_time[0]["min(query_date)"], minmax_time[0]["max(query_date)"]
print(f'min_time: {min_time}, max_time: {max_time}, staleness days: {allowed_staleness_days}')

# Via the staleness check, we are actually removing items when event_time is MORE than N days before min_time
# Usage: datediff ( enddate, startdate ) - returns days

filtered = feature_store_active_df.filter(
    (feature_store_active_df.event_time <= max_time) & 
    (datediff(lit(min_time), feature_store_active_df.event_time) <= allowed_staleness_days)
)

min_time: 2021-03-25T00:00:01Z, max_time: 2021-03-31T23:58:22Z, staleness days: 14


#### Perform the actual point-in-time correct history query

In [44]:
%%time
t_joined = (filtered.join(entity_df, filtered.consumer_id == entity_df.consumer_id, 'inner')
    .drop(entity_df.consumer_id)
    .drop(entity_df._tmp_id))

# Filter out data from after query time to remove future data leakage.
# Also filter out data that is older than our allowed staleness window (days before each query time)

drop_future_and_stale_df = t_joined.filter(
    (t_joined.event_time <= entity_df.query_date)
    & (datediff(entity_df.query_date, t_joined.event_time) <= allowed_staleness_days))

drop_future_and_stale_df.select('consumer_id','query_date','avg_amt_last_7d','event_time','write_time')\
    .where(drop_future_and_stale_df.consumer_id == test_consumer_id)\
    .orderBy(col('query_date').desc(),col('event_time').desc(),col('write_time').desc())\
    .show(15,False)

# Group by record id and query timestamp, select only the latest remaining record by event time,
# using write time as a tie breaker to account for any more recent backfills or data corrections.

latest = drop_future_and_stale_df.rdd.map(lambda x: (f'{x.consumer_id}-{x.query_date}', x))\
            .reduceByKey(
                lambda x, y: x if (x.event_time, x.write_time) > (y.event_time, y.write_time) else y).values()
latest_df = latest.toDF(drop_future_and_stale_df.schema)

+------------------+--------------------+------------------+--------------------+-----------------------+
|consumer_id       |query_date          |avg_amt_last_7d   |event_time          |write_time             |
+------------------+--------------------+------------------+--------------------+-----------------------+
|MOXJ24964715530559|2021-03-31T09:42:23Z|338.835           |2021-03-29T23:59:59Z|2021-05-26 15:07:17.894|
|MOXJ24964715530559|2021-03-31T09:42:23Z|326.344           |2021-03-28T23:59:59Z|2021-05-26 15:06:50.839|
|MOXJ24964715530559|2021-03-31T09:42:23Z|415.32454545454544|2021-03-27T23:59:59Z|2021-05-26 15:06:51.072|
|MOXJ24964715530559|2021-03-31T09:42:23Z|305.768           |2021-03-26T23:59:59Z|2021-05-26 15:06:51.065|
|MOXJ24964715530559|2021-03-31T09:42:23Z|875.9472727272729 |2021-03-25T23:59:59Z|2021-05-26 15:07:17.898|
|MOXJ24964715530559|2021-03-31T09:42:23Z|948.347           |2021-03-24T23:59:59Z|2021-05-26 15:06:50.968|
|MOXJ24964715530559|2021-03-31T09:42:23Z|942.4

In [45]:
latest_df.select('consumer_id', 'query_date', 'avg_amt_last_7d', 'event_time', 'write_time')\
    .where(latest_df.consumer_id == test_consumer_id)\
    .orderBy(col('query_date').desc(),col('event_time').desc(),col('write_time').desc())\
    .show(15,False)

+------------------+--------------------+------------------+--------------------+-----------------------+
|consumer_id       |query_date          |avg_amt_last_7d   |event_time          |write_time             |
+------------------+--------------------+------------------+--------------------+-----------------------+
|MOXJ24964715530559|2021-03-31T09:42:23Z|338.835           |2021-03-29T23:59:59Z|2021-05-26 15:07:17.894|
|MOXJ24964715530559|2021-03-30T17:08:23Z|338.835           |2021-03-29T23:59:59Z|2021-05-26 15:07:17.894|
|MOXJ24964715530559|2021-03-30T00:27:10Z|338.835           |2021-03-29T23:59:59Z|2021-05-26 15:07:17.894|
|MOXJ24964715530559|2021-03-29T08:10:42Z|326.344           |2021-03-28T23:59:59Z|2021-05-26 15:06:50.839|
|MOXJ24964715530559|2021-03-28T15:41:12Z|415.32454545454544|2021-03-27T23:59:59Z|2021-05-26 15:06:51.072|
|MOXJ24964715530559|2021-03-27T22:49:54Z|305.768           |2021-03-26T23:59:59Z|2021-05-26 15:06:51.065|
|MOXJ24964715530559|2021-03-27T05:41:33Z|305.7

In [46]:
latest_df.count()

22546

In [47]:
cols_to_drop = ('api_invocation_time','write_time','is_deleted','cc_num','year','month','day','hour')
latest_df = latest_df.drop(*cols_to_drop)

In [48]:
latest_df.select('query_date','event_time','avg_amt_last_7d','num_trans_last_7d','consumer_id').show(5)

+--------------------+--------------------+------------------+-----------------+------------------+
|          query_date|          event_time|   avg_amt_last_7d|num_trans_last_7d|       consumer_id|
+--------------------+--------------------+------------------+-----------------+------------------+
|2021-03-25T00:00:03Z|2021-03-24T23:59:59Z| 691.0175999999999|               25|UUEG06702648357115|
|2021-03-25T00:00:05Z|2021-03-24T23:59:59Z|465.65999999999997|               30|TJOJ29599331508229|
|2021-03-25T00:09:12Z|2021-03-24T23:59:59Z|          432.0545|               20|MYWU90501165273252|
|2021-03-25T00:15:37Z|2021-03-24T23:59:59Z|          855.8095|               40|ZLUQ21854565784963|
|2021-03-25T00:18:08Z|2021-03-24T23:59:59Z| 572.3636666666664|               60|TMBQ32501873277485|
+--------------------+--------------------+------------------+-----------------+------------------+
only showing top 5 rows



In [49]:
latest_df.printSchema()

root
 |-- consumer_id: string (nullable = true)
 |-- num_trans_last_7d: long (nullable = true)
 |-- avg_amt_last_7d: double (nullable = true)
 |-- num_trans_last_1d: long (nullable = true)
 |-- avg_amt_last_1d: double (nullable = true)
 |-- event_time: string (nullable = true)
 |-- query_date: string (nullable = false)
 |-- amount: float (nullable = false)
 |-- fraud_label: integer (nullable = false)



## Create a sample training dataset with point-in-time queries against two feature groups

### Reusable function for point-in-time correct queries against a single feature group

In [None]:
def get_historical_feature_values_one_fg(
    fg_name: str, entity_df: DataFrame, spark: SparkSession,
    allowed_staleness_days: int = 14,
    remove_extra_columns: bool = True) -> DataFrame:
    
    # Get metadata for source feature group
    sm_client = boto3.Session().client(service_name='sagemaker')
    feature_group_info = sm_client.describe_feature_group(FeatureGroupName=fg_name)

    # Get the names of this feature group's RecordId and EventTime features
    record_id_name = feature_group_info['RecordIdentifierFeatureName']
    event_time_name = feature_group_info['EventTimeFeatureName']
    
    # Get S3 Location of this feature group's offline store. 
    # Note Spark's Parquet file reader requires replacement of 's3' with 's3a'
    resolved_offline_store_s3_location = \
        feature_group_info['OfflineStoreConfig']['S3StorageConfig']['ResolvedOutputS3Uri']
    offline_store_s3a_uri = resolved_offline_store_s3_location.replace("s3:", "s3a:")

    # Read the offline store into a dataframe
    feature_store_df = spark.read.parquet(offline_store_s3a_uri)
    
    # Filter out deleted records, if any
    fs_active_df = feature_store_df.filter(~feature_store_df.is_deleted)
    
    # Determine min and max time of query timestamps
    minmax_time = entity_df.agg(sql_min("query_time"), sql_max("query_time")).collect()
    min_time, max_time = minmax_time[0]["min(query_time)"], minmax_time[0]["max(query_time)"]
    
    # Remove all rows that are outside of our time window, allowing for a buffer of staleness days
    filtered_df = fs_active_df.filter(
        (fs_active_df[event_time_name] <= max_time) & 
        (datediff(lit(min_time), fs_active_df[event_time_name]) <= allowed_staleness_days))
    
    # Join on record id between the input entity dataframe and the feature history dataframe
    joined_df = filtered_df.join(entity_df, 
                              filtered_df[record_id_name] == entity_df[record_id_name], 'inner')\
                                .drop(entity_df[record_id_name])

    # Filter out data from after query time to remove future data leakage
    # Also filter out data that is beyond our allowed staleness window (days before each query time)
    drop_future_and_stale_df = joined_df.filter(
        (joined_df[event_time_name] <= entity_df.query_time)
        & (datediff(entity_df.query_time, joined_df[event_time_name]) <= allowed_staleness_days))

    # Group by composite key (to uniquely identify the combination of an entity id and a query timestamp),
    # and keep only the very latest remaining feature vector coming closest to the input timestamp.
    # Use write time as a tie breaker to account for any more recent backfills or data corrections.
    latest = drop_future_and_stale_df.rdd.map(lambda x: (f'{x[record_id_name]}-{x.query_time}', x))\
                .reduceByKey(
                    lambda x, y: x if (x[event_time_name], x.write_time) > 
                                      (y[event_time_name], y.write_time) else y).values()
    latest_df = latest.toDF(drop_future_and_stale_df.schema)
    
    # Clean up excess columns
    if remove_extra_columns:
        cols_to_drop = ('api_invocation_time','write_time','is_deleted',
                        record_id_name,'query_time',event_time_name,
                        'year','month','day','hour')
        latest_df = latest_df.drop(*cols_to_drop)
    
    # Return results of point in time query
    return latest_df

### Use a handful of transactions from our transactions history as a base 
Note that we select a pair of entity identifiers, `consumer_id` and `cc_num` to drive corresponding 
queries against feature value history for those entities. We also add a monotonically increasing temporary
identifier to the query dataset. This will let us do an accurate final join of the results of each
entity-specific point-in-time query result into a combined training dataset containing multiple 
feature vectors for each transaction.

In [None]:
sample_count = 10
my_entity_df = transactions_df.select('consumer_id','cc_num','event_time','amount','fraud_label')\
                    .orderBy(col('event_time').desc()).limit(sample_count)
my_entity_cols = ['consumer_id','cc_num','query_time','amount','fraud_label']
my_entity_df = my_entity_df.toDF(*my_entity_cols)
my_entity_df = my_entity_df.withColumn("_tmp_id", monotonically_increasing_id())
my_entity_df.show(10)

### Do a point-in-time correct query to retrieve Consumer features for each transaction

In [None]:
cons_df = get_historical_feature_values_one_fg(CONS_FEATURE_GROUP, my_entity_df, spark)
cons_df.show(sample_count, False)

### Do a point-in-time correct query to retrieve Credit Card features for each transaction

In [None]:
card_df = get_historical_feature_values_one_fg(CARD_FEATURE_GROUP, my_entity_df, spark)
card_df.show(sample_count, False)

### Join the feature vectors from each feature group to form the final training dataset

In [None]:
# Drop id's, as they aren't used for training
cols_to_drop = ('cc_num','consumer_id')

# In this example, the column names are not unique. Add a prefix so they can be distinct in the training dataset
new_card_col_names = ('card_num_trans_last_7d', 'card_avg_amt_last_7d', 
                 'card_num_trans_last_1d', 'card_avg_amt_last_1d', 'amount', 'fraud_label', '_tmp_id')
card_df = card_df.drop(*cols_to_drop).toDF(*new_card_col_names)
card_df.show(sample_count)

new_cons_col_names = ('cons_num_trans_last_7d', 'cons_avg_amt_last_7d', 
                 'cons_num_trans_last_1d', 'cons_avg_amt_last_1d', 'amount', 'fraud_label', '_tmp_id')
cons_df = cons_df.drop(*cols_to_drop).toDF(*new_cons_col_names)
cons_df.show(sample_count)

In [None]:
# Join the feature vectors from each entity into a single training dataset
training_df = card_df.drop('fraud_label').drop('amount').join(cons_df, "_tmp_id", "outer").drop("_tmp_id").fillna(0)
training_df.show(sample_count)