### Introduction

In [2]:
# Prior art (deprecated approach): deep quantile regression using a "tilted" loss function — https://towardsdatascience.com/deep-quantile-regression-c85481548b5a

### Configuration

In [4]:
# Eager
# Switch TensorFlow (1.x API) to eager execution. This has to run at program
# start-up, before any graph is built; later test cells rely on it when they
# call .numpy() on layer outputs.
import tensorflow as tf
tf.enable_eager_execution()

In [5]:
# Fix random seeds for reproducibility: Python's `random`, NumPy's global RNG
# and TensorFlow's graph-level seed (TF 1.x API).
# NOTE(review): Spark's DataFrame.randomSplit takes its own `seed` argument and
# is not covered by these calls.
import random
import scipy  # imported for the version report in the next cell
import numpy as np
import keras  # imported for the version report in the next cell

RANDOM_SEED = 0
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
tf.set_random_seed(RANDOM_SEED)

# Only show TensorFlow log messages at ERROR level and above.
tf.logging.set_verbosity(tf.logging.ERROR)

In [6]:
# Versioning: record library versions so results can be tied to an environment.
for version_label, versioned_module in (("SCIPY=", scipy),
                                        ("NUMPY=", np),
                                        ("TENSORFLOW=", tf),
                                        ("KERAS=", keras)):
  print(version_label, versioned_module.__version__)

## Credentials

In [8]:
from pyspark.sql import SparkSession

# Azure Blob storage account hosting every container this notebook reads.
STORAGE_ACCOUNT = "enterprisedatalakeprod"

# SAS token for reading the storage account, fetched from the Databricks secret
# scope rather than hard-coded in the notebook.
key = dbutils.secrets.get(scope="application-secrets", key="enterprisedatalakeprodsas")

container_list = [
  'customer',
  'customerorder',
  'foundation',
  'pricing',
  'sales',
  'cdaworkspace'
]

# Create (or reuse) the Spark session *before* configuring it, so this cell does
# not depend on a `spark` left over from elsewhere. On Databricks getOrCreate()
# returns the cluster's existing session.
spark = SparkSession.builder.master("local[*]").getOrCreate()
sc = spark.sparkContext

# Grant read access to every container. spark.conf covers the DataFrame API;
# the same key is also set on the Hadoop configuration because RDD reads
# (e.g. sc.textFile in view_file) do not see session-level settings.
for container in container_list:
  sas_conf_key = "fs.azure.sas.{container}.{account}.blob.core.windows.net".format(
    container=container, account=STORAGE_ACCOUNT)
  spark.conf.set(sas_conf_key, key)
  sc._jsc.hadoopConfiguration().set(sas_conf_key, key)

## Read

In [10]:
import json

# Root folder of the production extracts in the cdaworkspace container
# (shared by view_blob_csv and view_file).
CDA_PROD_PATH = 'wasbs://cdaworkspace@enterprisedatalakeprod.blob.core.windows.net/cda/prod/'


def _register_view(view_name, df):
  """Register `df` as the temp view `view_name` and report its shape.

  Returns:
    (df, (row_count, column_count)). Computing the row count runs a Spark job.
  """
  df.createOrReplaceTempView(view_name)
  shape = (df.count(), len(df.columns))
  print("DATAFRAME_SHAPE=", shape)
  return df, shape


def view_avro(view_name, file_location):
  """Load Avro data at `file_location` into temp view `view_name`; return (df, shape)."""
  # Avro files carry their own schema; inferSchema is a CSV option and was ignored here.
  return _register_view(view_name, spark.read.format('avro').load(file_location))


def view_orc(view_name, file_location):
  """Load ORC data at `file_location` into temp view `view_name`; return (df, shape)."""
  # ORC files carry their own schema; inferSchema is a CSV option and was ignored here.
  return _register_view(view_name, spark.read.format('orc').load(file_location))


def view_header_csv(view_name, file_location):
  """Load a CSV with a header row (column types inferred) into temp view `view_name`.

  Returns:
    (df, shape) as produced by _register_view.
  """
  df = (spark.read.format('csv')
        .option("inferSchema", "true")
        .option("header", 'true')
        .load(file_location))
  return _register_view(view_name, df)


def view_blob_csv(view_name, file_name):
  """Load `file_name` under CDA_PROD_PATH exactly like view_header_csv."""
  return view_header_csv(view_name, CDA_PROD_PATH + file_name)


def view_schema_csv(view_name, file_location, table_schema):
  """Load a CSV using the explicit Spark schema `table_schema`.

  The header option is not set, so Spark's default applies (the first line is
  treated as data).
  """
  df = spark.read.format('csv').schema(table_schema).load(file_location)
  return _register_view(view_name, df)


def view_file(file_name):
  """Return an RDD of the text lines of `file_name` under CDA_PROD_PATH."""
  return sc.textFile(CDA_PROD_PATH + file_name)

In [11]:
%fs ls

path,name,size
dbfs:/FileStore/,FileStore/,0
dbfs:/databricks/,databricks/,0
dbfs:/databricks-datasets/,databricks-datasets/,0
dbfs:/databricks-results/,databricks-results/,0
dbfs:/delta/,delta/,0
dbfs:/ml/,ml/,0
dbfs:/mnt/,mnt/,0
dbfs:/tmp/,tmp/,0
dbfs:/user/,user/,0


## Training Dataset

#### Purchase

In [14]:
# Load the purchase (sales) regression extract and register it as temp view `purchase_regression_view`.
purchase_regression_df, purchase_regression_shape = view_blob_csv('purchase_regression_view', 'LTV_BF_US_Regression_sales.csv') 

In [15]:
# Inspect the column types inferred by the CSV reader (view_blob_csv sets inferSchema=true).
purchase_regression_df.printSchema()

In [16]:
# Per-column summary statistics (count, mean, stddev, min, max); this scans the whole dataset.
display(purchase_regression_df.describe())

summary,customer_key,spend_6mo_sls,repeat_spend_6mo_sls,item_qty_6mo_sls,spend_12mo_sls,repeat_spend_12mo_sls,item_qty_12mo_sls,spend_24mo_sls,repeat_spend_24mo_sls,item_qty_24mo_sls,onsale_qty_6mo_sls,onsale_qty_12mo_sls,onsale_qty_24mo_sls,num_txns_6mo_sls,num_txns_12mo_sls,num_txns_24mo_sls,repeat_num_txns_6mo_sls,repeat_num_txns_12mo_sls,repeat_num_txns_24mo_sls,spend_6mo_men_sls,item_qty_6mo_men_sls,spend_6mo_women_sls,item_qty_6mo_women_sls,spend_6mo_accessories_sls,item_qty_6mo_accessories_sls,spend_6mo_plcb_sls,item_qty_6mo_plcb_sls,spend_12mo_plcb_sls,item_qty_12mo_plcb_sls,spend_24mo_plcb_sls,item_qty_24mo_plcb_sls,num_plcb_txns_6mo_sls,num_plcb_txns_12mo_sls,num_plcb_txns_24mo_sls,spend_6mo_sls_sb,item_qty_6mo_sls_sb,spend_12mo_sls_sb,item_qty_12mo_sls_sb,spend_24mo_sls_sb,item_qty_24mo_sls_sb,onsale_qty_6mo_sls_sb,onsale_qty_12mo_sls_sb,onsale_qty_24mo_sls_sb,num_txns_6mo_sls_sb,num_txns_12mo_sls_sb,num_txns_24mo_sls_sb,spend_6mo_plcb_sls_sb,item_qty_plcb_6mo_sls_sb,spend_12mo_plcb_sls_sb,item_qty_plcb_12mo_sls_sb,spend_24mo_plcb_sls_sb,item_qty_plcb_24mo_sls_sb,num_plcb_txns_6mo_sls_sb,num_plcb_txns_12mo_sls_sb,num_plcb_txns_24mo_sls_sb,spend_6mo_onl_sls,item_qty_6mo_onl_sls,spend_12mo_onl_sls,item_qty_12mo_onl_sls,spend_24mo_onl_sls,item_qty_24mo_onl_sls,num_txns_6mo_onl_sls,num_txns_12mo_onl_sls,num_txns_24mo_onl_sls,spend_6mo_rtl_sls,item_qty_6mo_rtl_sls,spend_12mo_rtl_sls,item_qty_12mo_rtl_sls,spend_24mo_rtl_sls,item_qty_24mo_rtl_sls,num_txns_6mo_rtl_sls,num_txns_12mo_rtl_sls,num_txns_24mo_rtl_sls,stores_shopped_6mo_sls,stores_shopped_12mo_sls,stores_shopped_24mo_sls,customer_duration_sb,days_first_pur_sb,days_last_pur_sb,customer_duration_sb_onl,days_first_pur_sb_onl,days_last_pur_sb_onl,customer_duration_sb_rtl,days_first_pur_sb_rtl,days_last_pur_sb_rtl,customer_duration,days_first_pur,days_last_pur,customer_duration_onl,days_first_pur_onl,days_last_pur_onl,customer_duration_rtl,days_first_pur_rtl,days_last_pur_rtl,days_on_books,card_status,mapped
_source,net_txn_amt,validation_spend_sls,validation_item_qty_sls,validation_num_txns_sls,segment_flags
count,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696,6492696,6492696.0,6492696.0,6492696.0,6492696.0,6492696
mean,195510649.6123,47.991817372949725,2.3076863278369157,5.68607292861858,84.27506896363094,9.43248529424602,6.734093679557923,153.65171647028055,71.99389938632052,8.148729819079739,2.828251319352329,3.485142045249316,4.256150529951614,0.7188837759160469,1.2550513230588494,2.286569993225208,0.2744287394321268,0.5931203778087929,1.286569530466202,16.95563877317821,0.8896636484787442,27.421670233766875,1.7067847588472094,3.4647073850374053,0.3754423936925399,16.902894266111776,8.061920711389162,29.87642543867856,10.18429340204908,57.51325359449921,13.15848444955321,0.249577655280327,0.4409940618764322,0.8320679959573373,137.87026244104493,16.78336294746558,255.0169460498067,25.158532577005083,502.27293904408754,41.59375663711039,7.3200567247918205,10.663294769588594,16.109683329002525,2.592796021117277,4.8146322744435786,9.412570136473066,61.82156433167344,21.897134727396477,112.66986833358406,34.33904238279827,227.757311401916,59.08027624943179,1.1202436799338604,2.055020042207546,4.115213809804964,9.915344282253775,5.708649104251279,16.604729160276428,6.615364639470138,24.096246959664768,7.1313902121625325,0.6686031383285629,1.0967228402374514,1.5620150403159985,38.076473090699615,5.246596332589702,67.67033980337115,6.169093387674354,129.55546951062723,7.628984509858963,0.6376354607361825,1.1277278461892664,2.144373729899643,0.4597418943143421,0.7261405784005419,1.200829898711368,359.4754884221109,514.2806363083013,154.8054035817951,253.08301793694315,450.5595634439517,197.47654550700855,331.3218918899517,503.5371295480276,172.21551066324844,110.71702811230112,376.0726912644674,265.35566315216624,67.8745076222507,280.03722132635136,212.16271370410067,106.71503466307678,382.4999508670045,275.78491620392776,790.6070854206511,,,46.06590967913401,52.043700371933255,9.46991484541504,2.604446218672872,
stddev,108574918.52506672,108.26365583092776,40.53191933238156,9.123645379971965,168.77922847735843,71.94363540850364,13.254832108257297,260.3013495895526,245.01708249092647,17.157263742676808,5.183460159046788,7.366713182195041,9.41130550257932,1.3126106455535775,2.0656516151773103,3.324041608441432,1.0821815266395929,1.9084304648003103,3.32404157872857,59.11554547980335,3.6187883332464894,74.79966992766245,4.155328815293347,14.01775076968558,2.155289532878394,77.09156593316153,13.564614404352971,122.85585482972976,20.05317914872301,205.65736080608983,26.74211151292139,0.968814793501282,1.5959587616134585,2.709530300049558,502.9749722616986,50.62527112323654,828.7335365131261,80.54608214006053,1963.7355527388168,173.96433255104628,25.763916086589813,40.75561180591716,86.13790101786151,6.313466259638712,11.346368379193477,29.78926557492318,235.80301122353248,28.17959477463077,410.1868746293906,47.270832264636624,787.4238423214722,86.2104967029388,3.3669308183576643,6.040484944993357,11.706541410547258,54.36429885427284,8.070565002926712,82.16936985854151,11.07168466660805,112.22262718390274,13.45690137174947,1.1865763734892452,1.767991372123605,2.432439868248233,90.24476726517112,8.414593117876938,142.08587309587978,12.185000750488298,226.3698364867285,15.872246718472562,1.1684718055114158,1.8479433528964369,3.0698997759248208,0.5856940337550701,0.6506624970032084,0.5831623718545268,256.36717817483014,196.08319435906975,174.15403344144838,254.758862204911,215.56637220430824,193.47628086658588,253.50789484435532,196.0766317053427,180.3339275963341,188.37913140600165,216.63056265414977,207.4821891597118,143.00205160788929,199.75625962369355,182.65801942677112,185.56089546417851,215.73366353660563,208.71989742833995,1022.2425809957432,,,143.62819720649688,161.55771841218316,16.59788751263663,3.4986519171588664,
min,115.0,-15.49,-3116.709999999994,1.0,-65.37,-3116.709999999994,1.0,-41.89999999999998,-6259.769999999952,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,-22.450000000000003,0.0,-19.17,0.0,-89.47,0.0,-15.49,1.0,-65.37,1.0,-41.89999999999998,1.0,0.0,0.0,0.0,-837.7099999999998,1.0,-837.7099999999998,1.0,-182.47,1.0,0.0,0.0,0.0,0.0,0.0,0.0,-837.7099999999998,1.0,-837.7099999999998,1.0,-182.47,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,-15.49,1.0,-65.36999999999999,1.0,-41.900000000000006,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,NON_CARD,Card Program,-2013.219999999999,-381.15,1.0,1.0,
max,395764197.0,48379.44000000747,11595.829999999876,,87440.25999999311,24663.100000000464,,125408.16999996906,125108.36999996909,,,,,,,,,,,45035.84000000515,,13505.399999999983,,8873.699999999784,,44209.67000000584,,80462.16000000204,,117470.24999997408,,,,,798741.6299995629,,924263.3599994062,,3229154.5300163096,,,,,,,,41661.48000000013,,61343.17999999848,,128544.580000015,,,,,13625.179999999946,,39581.22000000034,,62734.9799999998,,,,,48379.44000000788,,87440.25999999461,,125108.36999996733,,,,,,,,,,,,,,,,,,,,,,,,,,,NON_CARD,wifi,68442.80000000008,69086.37000000008,,,svsbsc_rtl


#### Return

In [18]:
# Load the return regression extract and register it as temp view `return_regression_view`.
return_regression_df, return_regression_shape = view_blob_csv('return_regression_view', 'LTV_BF_US_Regression_return.csv') 

In [19]:
# Inspect the column types inferred by the CSV reader (view_blob_csv sets inferSchema=true).
return_regression_df.printSchema()

In [20]:
# Per-column summary statistics (count, mean, stddev, min, max); this scans the whole dataset.
display(return_regression_df.describe())

summary,customer_key,spend_6mo_rtn,item_qty_6mo_rtn,spend_12mo_rtn,item_qty_12mo_rtn,spend_24mo_rtn,item_qty_24mo_rtn,onsale_qty_6mo_rtn,onsale_qty_12mo_rtn,onsale_qty_24mo_rtn,num_txns_6mo_rtn,num_txns_12mo_rtn,num_txns_24mo_rtn,spend_6mo_men_rtn,item_qty_6mo_men_rtn,spend_6mo_women_rtn,item_qty_6mo_women_rtn,spend_6mo_accessories_rtn,item_qty_6mo_accessories_rtn,spend_6mo_rtn_sb,item_qty_6mo_rtn_sb,spend_12mo_rtn_sb,item_qty_12mo_rtn_sb,spend_24mo_rtn_sb,item_qty_24mo_rtn_sb,onsale_qty_6mo_rtn_sb,onsale_qty_12mo_rtn_sb,onsale_qty_24mo_rtn_sb,num_txns_6mo_rtn_sb,num_txns_12mo_rtn_sb,num_txns_24mo_rtn_sb,spend_6mo_plcb_rtn_sb,item_qty_6mo_plcb_rtn_sb,spend_12mo_plcb_rtn_sb,item_qty_12mo_plcb_rtn_sb,spend_24mo_plcb_rtn_sb,item_qty_24mo_plcb_rtn_sb,num_txns_plcb_6mo_rtn_sb,num_txns_plcb_12mo_rtn_sb,num_txns_plcb_24mo_rtn_sb,card_status,mapped_source,validation_spend_rtn,validation_item_qty_rtn,validation_num_txns_rtn,segment_flags
count,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696.0,6492696,6492696,6492696.0,6492696.0,6492696.0,6492696
mean,195510649.6123,4.708404126731888,-2.980244666635949,8.40134604177409,-3.3146505898760803,14.592674751758318,-3.641937767731288,-1.7281758282265125,-1.9701903950797928,-2.2013091985532323,0.4576324251841163,0.8191872831580077,1.4233378857169028,1.4002834092341634,-0.319226703649935,3.1142475714257687,-1.010166416900463,0.191251189028405,-0.0688961642456482,21.44860727346696,-5.16679163441343,41.117603901355885,-6.802793772390708,81.08601634358664,-9.711233971739356,-2.1587257970702938,-2.697071530868953,-3.506062262906617,0.9901280644442326,1.8845892114059053,3.718425103642344,11.292422512928496,-6.179977747820405,21.45118816435828,-8.498250690591968,43.04313638279044,-12.663896059926362,0.5135512897487255,0.9712690332363816,1.9526981498094036,,,5.977790692806211,-3.746976951486265,1.7878381389834248,
stddev,108574918.52506664,29.0342673897853,3.900788011296098,43.00482727367231,4.909152060872182,63.83688979869292,6.04529619150927,2.579970390211228,3.21020205939559,3.884539740444677,0.9311629935008234,1.4034250784399205,2.071590801828941,14.193494897957445,1.2441767230907006,22.684407086232348,2.522245193505793,2.992095396911814,0.4996231150874025,113.80013292382468,9.064576923719224,198.9017046118462,13.724753573654247,372.20710186219503,22.74466488930473,4.289009465518547,5.889909202487188,8.816305231628718,2.231288380499776,3.9434749186843607,7.463944514725719,88.17572325170333,10.707812855558938,157.05466886387063,16.841849833889956,299.078318252441,28.658899599633777,1.868438947936844,3.3580119842372635,6.491710129636812,,,39.54018424155804,5.732527011675379,1.972166764043451,
min,115.0,0.0,-1.0,0.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0,0.0,0.0,-1.0,0.0,-1.0,0.0,-1.0,0.0,-1.0,0.0,-1.0,0.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0,0.0,0.0,-1.0,0.0,-1.0,0.0,-1.0,0.0,0.0,0.0,NON_CARD,Card Program,0.0,-1.0,1.0,
max,395764197.0,6640.179999999973,,9865.57999999996,,14315.300000000014,,,,,,,,3863.7099999999873,,3799.33999999999,,662.2499999999998,,36946.82000000004,,61158.26999999984,,146856.58999999976,,,,,,,,29303.650000000074,,45779.29999999994,,72656.64000000013,,,,,NON_CARD,wifi,10245.859999999966,,,svsbsc_rtl


#### Labels

In [22]:
# Regression targets: purchase spend and return spend in the validation window.
purchase_label_name = 'validation_spend_sls'
return_label_name = 'validation_spend_rtn'

# Every column that describes the outcome (labels and their siblings). These
# must never be used as model inputs.
ground_truth_col_names = [
    purchase_label_name, 'validation_num_txns_sls', 'validation_item_qty_sls',
    return_label_name, 'validation_num_txns_rtn', 'validation_item_qty_rtn',
    'net_txn_amt',
]

#### Join Purchase and Return

In [24]:
# From the return data only the join key and the return label are needed.
to_join_cols = ['customer_key', return_label_name]
return_regression_select_df = return_regression_df.select(to_join_cols)

In [25]:
# Inner join on customer_key: customers present in only one of the two extracts
# are dropped. Only the return label column is brought over from the return data.
overall_df = purchase_regression_df.join(return_regression_select_df, on='customer_key', how='inner')   

In [26]:
# Confirm the joined schema: all purchase columns plus the return label.
overall_df.printSchema()

In [27]:
# Preview the joined frame.
# NOTE(review): the output includes customer_key values; clear it before sharing.
display(overall_df)

customer_key,spend_6mo_sls,repeat_spend_6mo_sls,item_qty_6mo_sls,spend_12mo_sls,repeat_spend_12mo_sls,item_qty_12mo_sls,spend_24mo_sls,repeat_spend_24mo_sls,item_qty_24mo_sls,onsale_qty_6mo_sls,onsale_qty_12mo_sls,onsale_qty_24mo_sls,num_txns_6mo_sls,num_txns_12mo_sls,num_txns_24mo_sls,repeat_num_txns_6mo_sls,repeat_num_txns_12mo_sls,repeat_num_txns_24mo_sls,spend_6mo_men_sls,item_qty_6mo_men_sls,spend_6mo_women_sls,item_qty_6mo_women_sls,spend_6mo_accessories_sls,item_qty_6mo_accessories_sls,spend_6mo_plcb_sls,item_qty_6mo_plcb_sls,spend_12mo_plcb_sls,item_qty_12mo_plcb_sls,spend_24mo_plcb_sls,item_qty_24mo_plcb_sls,num_plcb_txns_6mo_sls,num_plcb_txns_12mo_sls,num_plcb_txns_24mo_sls,spend_6mo_sls_sb,item_qty_6mo_sls_sb,spend_12mo_sls_sb,item_qty_12mo_sls_sb,spend_24mo_sls_sb,item_qty_24mo_sls_sb,onsale_qty_6mo_sls_sb,onsale_qty_12mo_sls_sb,onsale_qty_24mo_sls_sb,num_txns_6mo_sls_sb,num_txns_12mo_sls_sb,num_txns_24mo_sls_sb,spend_6mo_plcb_sls_sb,item_qty_plcb_6mo_sls_sb,spend_12mo_plcb_sls_sb,item_qty_plcb_12mo_sls_sb,spend_24mo_plcb_sls_sb,item_qty_plcb_24mo_sls_sb,num_plcb_txns_6mo_sls_sb,num_plcb_txns_12mo_sls_sb,num_plcb_txns_24mo_sls_sb,spend_6mo_onl_sls,item_qty_6mo_onl_sls,spend_12mo_onl_sls,item_qty_12mo_onl_sls,spend_24mo_onl_sls,item_qty_24mo_onl_sls,num_txns_6mo_onl_sls,num_txns_12mo_onl_sls,num_txns_24mo_onl_sls,spend_6mo_rtl_sls,item_qty_6mo_rtl_sls,spend_12mo_rtl_sls,item_qty_12mo_rtl_sls,spend_24mo_rtl_sls,item_qty_24mo_rtl_sls,num_txns_6mo_rtl_sls,num_txns_12mo_rtl_sls,num_txns_24mo_rtl_sls,stores_shopped_6mo_sls,stores_shopped_12mo_sls,stores_shopped_24mo_sls,customer_duration_sb,days_first_pur_sb,days_last_pur_sb,customer_duration_sb_onl,days_first_pur_sb_onl,days_last_pur_sb_onl,customer_duration_sb_rtl,days_first_pur_sb_rtl,days_last_pur_sb_rtl,customer_duration,days_first_pur,days_last_pur,customer_duration_onl,days_first_pur_onl,days_last_pur_onl,customer_duration_rtl,days_first_pur_rtl,days_last_pur_rtl,days_on_books,card_status,mapped_source,
net_txn_amt,validation_spend_sls,validation_item_qty_sls,validation_num_txns_sls,segment_flags,validation_spend_rtn
8440,73.77999999999999,0.0,4.0,73.77999999999999,0.0,4.0,85.16,58.77999999999999,5.0,2.0,2.0,3.0,1.0,1.0,2.0,0.0,0.0,1.0,43.2,2.0,30.58,2.0,0.0,0.0,73.77999999999999,4.0,73.77999999999999,4.0,85.16,5.0,1.0,1.0,2.0,1079.6600000000003,51.0,1933.7000000000007,89.0,4890.689999999991,191.0,17.0,37.0,74.0,8.0,16.0,41.0,1079.6600000000003,51.0,1933.7000000000007,89.0,4890.689999999991,191.0,8.0,16.0,41.0,0.0,,0.0,,0.0,,,,,73.77999999999999,4.0,73.77999999999999,4.0,85.15999999999998,5.0,1.0,1.0,2.0,1.0,1.0,1.0,705.0,718.0,13.0,240.0,566.0,326.0,705.0,718.0,13.0,494.0,529.0,35.0,,,,494.0,529.0,35.0,529.0,NON_CARD,cic,0.0,0.0,,,mvmbsc_rtl,0.0
13248,0.0,0.0,,0.0,0.0,,178.93,148.94,7.0,,,7.0,0.0,0.0,3.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,,0.0,0.0,0.0,591.6700000000001,36.0,1162.4800000000005,68.0,2260.789999999998,131.0,26.0,35.0,57.0,10.0,18.0,31.0,0.0,,0.0,,0.0,,0.0,0.0,0.0,0.0,,0.0,,148.94,6.0,0.0,0.0,2.0,0.0,,0.0,,29.99,1.0,0.0,0.0,1.0,0.0,0.0,1.0,708.0,716.0,8.0,708.0,716.0,8.0,632.0,681.0,49.0,205.0,592.0,387.0,28.0,415.0,387.0,0.0,592.0,592.0,1374.0,NON_CARD,cic,0.0,0.0,,,mvmbmc,0.0
31156,23.39,0.0,2.0,23.39,0.0,2.0,42.38,23.39,4.0,1.0,1.0,2.0,1.0,1.0,2.0,0.0,0.0,1.0,18.9,1.0,0.0,0.0,4.49,1.0,23.39,2.0,23.39,2.0,23.39,2.0,1.0,1.0,1.0,188.48,13.0,240.96,17.0,279.64000000000004,20.0,11.0,12.0,14.0,7.0,8.0,9.0,145.57999999999998,10.0,145.57999999999998,10.0,184.26,13.0,4.0,4.0,5.0,0.0,,0.0,,0.0,,,,,23.39,2.0,23.39,2.0,42.38,4.0,1.0,1.0,2.0,1.0,1.0,2.0,493.0,494.0,1.0,160.0,331.0,171.0,493.0,494.0,1.0,321.0,492.0,171.0,,,,321.0,492.0,171.0,1468.0,NON_CARD,cic,0.0,0.0,,,mvmbsc_rtl,0.0
32912,171.78000000000003,0.0,8.0,267.25000000000006,0.0,13.0,889.0600000000003,776.1400000000003,47.0,5.0,5.0,28.0,2.0,5.0,13.0,1.0,4.0,12.0,0.0,0.0,163.29,7.0,8.49,1.0,171.78000000000003,8.0,267.25000000000006,13.0,889.0600000000003,47.0,2.0,5.0,13.0,1630.4699999999998,63.0,2756.659999999997,128.0,5982.219999999987,279.0,10.0,28.0,78.0,18.0,38.0,86.0,1480.010000000001,60.0,2544.599999999998,120.0,5746.189999999986,267.0,16.0,35.0,82.0,125.44,6.0,125.44,6.0,599.2400000000001,25.0,1.0,1.0,5.0,46.34,2.0,141.80999999999997,7.0,289.8200000000001,22.0,1.0,4.0,8.0,1.0,1.0,1.0,701.0,709.0,8.0,589.0,630.0,41.0,701.0,709.0,8.0,657.0,701.0,44.0,435.0,611.0,176.0,657.0,701.0,44.0,2967.0,NON_CARD,cic,5.390000000000015,194.12000000000003,7.0,2.0,mvmbmc,188.73
33013,0.0,0.0,,0.0,0.0,,123.97,92.77,4.0,,,4.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,0.0,0.0,,0.0,,0.0,,0.0,0.0,0.0,445.1000000000001,29.0,863.4600000000005,63.0,1631.830000000001,108.0,21.0,38.0,52.0,11.0,23.0,41.0,0.0,,0.0,,0.0,,0.0,0.0,0.0,0.0,,0.0,,123.97,4.0,0.0,0.0,1.0,0.0,,0.0,,0.0,,,,,,,,657.0,693.0,36.0,557.0,593.0,36.0,636.0,693.0,57.0,0.0,437.0,437.0,0.0,437.0,437.0,,,,437.0,NON_CARD,online order,94.78,128.38,4.0,2.0,mvmbsc_onl,33.6
39473,16.0,0.0,1.0,16.0,0.0,1.0,16.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,16.0,1.0,0.0,0.0,16.0,1.0,16.0,1.0,16.0,1.0,1.0,1.0,1.0,0.0,4.0,19.379999999999995,7.0,54.63,15.0,0.0,2.0,2.0,1.0,2.0,4.0,0.0,4.0,19.379999999999995,7.0,54.63,15.0,1.0,2.0,4.0,16.0,1.0,16.0,1.0,16.0,1.0,1.0,1.0,1.0,0.0,,0.0,,0.0,,,,,,,,558.0,597.0,39.0,558.0,597.0,39.0,,,,0.0,4.0,4.0,0.0,4.0,4.0,,,,4.0,NON_CARD,online order,0.0,0.0,,,mvmbsc_onl,0.0
40436,0.0,0.0,,0.0,0.0,,163.45000000000005,61.98000000000005,7.0,,,4.0,0.0,0.0,2.0,0.0,0.0,1.0,0.0,0.0,0.0,,0.0,0.0,0.0,,0.0,,0.0,,0.0,0.0,0.0,0.0,,0.0,,64.72,4.0,,,4.0,0.0,0.0,1.0,0.0,,0.0,,64.72,4.0,0.0,0.0,1.0,0.0,,0.0,,61.98,2.0,0.0,0.0,1.0,0.0,,0.0,,101.46999999999998,5.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,456.0,456.0,,,,0.0,456.0,456.0,28.0,484.0,456.0,0.0,456.0,456.0,0.0,484.0,484.0,456.0,NON_CARD,online order,0.0,0.0,,,mvmbmc,0.0
55474,19.98,0.0,2.0,19.98,0.0,2.0,19.98,0.0,2.0,2.0,2.0,2.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,19.98,2.0,0.0,0.0,0.0,,0.0,,0.0,,0.0,0.0,0.0,0.0,,0.0,,10.0,2.0,,,0.0,0.0,0.0,1.0,0.0,,0.0,,0.0,,0.0,0.0,0.0,0.0,,0.0,,0.0,,,,,19.98,2.0,19.98,2.0,19.98,2.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,524.0,524.0,,,,0.0,524.0,524.0,0.0,181.0,181.0,,,,0.0,181.0,181.0,180.0,NON_CARD,cic,0.0,0.0,,,mvmbsc_rtl,0.0
71575,5.879999999999999,0.0,1.0,5.879999999999999,0.0,1.0,80.68999999999998,45.62999999999998,6.0,0.0,0.0,1.0,1.0,1.0,3.0,0.0,0.0,2.0,0.0,0.0,5.879999999999999,1.0,0.0,0.0,5.879999999999999,1.0,5.879999999999999,1.0,80.68999999999998,6.0,1.0,1.0,3.0,337.0300000000002,47.0,751.5600000000004,84.0,1391.760000000001,150.0,17.0,29.0,53.0,15.0,31.0,53.0,302.8100000000002,44.0,717.3400000000004,81.0,1357.5400000000009,147.0,14.0,30.0,52.0,0.0,,0.0,,35.06,3.0,0.0,0.0,1.0,5.879999999999999,1.0,5.879999999999999,1.0,45.63,3.0,1.0,1.0,2.0,1.0,1.0,1.0,724.0,725.0,1.0,712.0,725.0,13.0,678.0,679.0,1.0,432.0,502.0,70.0,0.0,502.0,502.0,364.0,434.0,70.0,3819.0,NON_CARD,cic,33.19999999999999,33.19999999999999,5.0,2.0,mvmbmc,0.0
76584,0.0,0.0,,0.0,0.0,,42.48,0.0,2.0,,,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,0.0,,0.0,,0.0,0.0,0.0,23.0,3.0,333.57000000000005,30.0,495.17,46.0,0.0,10.0,11.0,1.0,2.0,6.0,0.0,,310.57000000000005,27.0,412.2300000000001,39.0,0.0,1.0,2.0,0.0,,0.0,,0.0,,,,,0.0,,0.0,,42.48,2.0,0.0,0.0,1.0,0.0,0.0,1.0,450.0,626.0,176.0,,,,450.0,626.0,176.0,0.0,487.0,487.0,,,,0.0,487.0,487.0,2612.0,NON_CARD,cic,0.0,0.0,,,mvmbsc_rtl,0.0


## Model Training

In [29]:
import warnings
# Silence every Python warning to keep cell output readable.
# NOTE(review): this also hides useful warnings (e.g. deprecations); consider
# filtering by category instead.
warnings.filterwarnings('ignore')

# General libraries
import sys,os,time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import seaborn as sb
from statistics import median 

# ML libraries 
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import RobustScaler
from sklearn.pipeline import Pipeline
from sklearn.utils import resample
from sklearn.ensemble import RandomForestRegressor as rfr
from keras import layers
import tensorflow as tf
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn import preprocessing
from keras import regularizers
from keras.layers import Dropout
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from keras.constraints import max_norm
from keras.models import model_from_json
from keras.layers.advanced_activations import LeakyReLU
from keras.callbacks import ModelCheckpoint
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import median_absolute_error
from sklearn.preprocessing import MinMaxScaler
import keras.backend as K

In [30]:
from pyspark.sql.types import DoubleType
import pyspark.sql.functions as F


# Suffix cast_column appends to every column name. Label/ground-truth names are
# re-derived with the same suffix further down (e.g. purchase_label_renamed).
CAST_COLUMN_TAG = '_'
  
def na_and_fill(df):
  """Return `df` without any row that has a missing value.

  Despite the name nothing is filled: incomplete rows are dropped outright.
  """
  # TODO: revisit this assumption vs fillna(0)?
  cleaned = df.dropna()
  return cleaned
  
def cast_column(df):
  """Cast every column of `df` to double and suffix its name with CAST_COLUMN_TAG.

  Values that cannot be parsed as numbers become null (Spark's default,
  non-ANSI cast); drop_null can remove such columns afterwards.

  Args:
    df: Spark DataFrame.

  Returns:
    A new DataFrame with the same column order, each column renamed to
    `<name><CAST_COLUMN_TAG>`.
  """
  # Iterate the *argument's* columns rather than a notebook global, and build a
  # single projection: chaining withColumn in a loop grows the query plan by one
  # projection per column.
  return df.select([
    df[col_name].cast(DoubleType()).alias(col_name + CAST_COLUMN_TAG)
    for col_name in df.columns
  ])

def drop_null(df):
  """Drop every column of `df` that contains at least one null.

  Returns:
    (df_without_those_columns, names_of_dropped_columns)
  """
  # One aggregate pass: for each column, count the rows where it is null.
  null_count_exprs = [
    F.count(F.when(F.col(name).isNull(), name)).alias(name)
    for name in df.columns
  ]
  null_counts = df.select(null_count_exprs).collect()[0].asDict()

  dropped = [name for name, n_nulls in null_counts.items() if n_nulls > 0]
  return df.drop(*dropped), dropped

# TODO: Cap outliers?

def filter_negative(df):
  """Drop rows in which any ground-truth column holds a negative value.

  Each name in `ground_truth_col_names` is looked up both as-is and with the
  CAST_COLUMN_TAG suffix that cast_column appends, so the filter works before
  or after casting. Names absent from `df` are skipped explicitly instead of
  relying on a swallowed exception.

  Zeros (e.g. customers with no returns) and nulls are kept. Null handling is
  left to na_and_fill / drop_null, so a column whose cast produced nulls cannot
  wipe out every row here.
  """
  for col_name in ground_truth_col_names:
    for candidate in (col_name, col_name + CAST_COLUMN_TAG):
      if candidate in df.columns:
        df = df.filter(df[candidate].isNull() | (df[candidate] >= 0))
  return df


In [31]:
# The join key is an identifier, not a feature.
df_in = overall_df.drop('customer_key')

# Drop incomplete rows -> cast everything to double (names gain CAST_COLUMN_TAG)
# -> filter on the ground-truth columns.
df = filter_negative(cast_column(na_and_fill(df_in)))

# Remove any column that still contains nulls after casting.
df_final, to_drop = drop_null(df)
print("COLUMN_TO_DROP=", to_drop)

assert df_final.count() > 0, "DATAFRAME_LENGTH_ERROR"

In [32]:
# 70/25/5 split, seeded with the notebook-wide RANDOM_SEED so split membership
# is reproducible across runs (randomSplit draws a fresh seed when none is given).
train_df, test_df, val_df = df_final.randomSplit([.7, .25, .05], seed=RANDOM_SEED)
train_df.cache()
test_df.cache()
val_df.cache()

# The counts materialize the caches and show the split sizes.
train_df.count(), test_df.count(), val_df.count()

In [33]:
# Preview the training split; cast columns carry the CAST_COLUMN_TAG suffix.
display(train_df)

spend_6mo_sls_,repeat_spend_6mo_sls_,spend_12mo_sls_,repeat_spend_12mo_sls_,spend_24mo_sls_,repeat_spend_24mo_sls_,spend_6mo_men_sls_,spend_6mo_women_sls_,spend_6mo_accessories_sls_,spend_6mo_plcb_sls_,spend_12mo_plcb_sls_,spend_24mo_plcb_sls_,spend_6mo_sls_sb_,spend_12mo_sls_sb_,spend_24mo_sls_sb_,spend_6mo_plcb_sls_sb_,spend_12mo_plcb_sls_sb_,spend_24mo_plcb_sls_sb_,spend_6mo_onl_sls_,spend_12mo_onl_sls_,spend_24mo_onl_sls_,spend_6mo_rtl_sls_,spend_12mo_rtl_sls_,spend_24mo_rtl_sls_,net_txn_amt_,validation_spend_sls_,validation_spend_rtn_
0.0,-16.450000000000017,0.0,-16.450000000000017,0.0,-16.450000000000017,0.0,0.0,0.0,0.0,0.0,0.0,280.5100000000001,375.80000000000007,375.8000000000002,264.06000000000006,359.35000000000014,359.3500000000001,0.0,0.0,0.0,0.0,0.0,0.0,3.5999999999999943,3.5999999999999943,0.0
0.0,0.0,0.0,0.0,0.0,-47.19,0.0,0.0,0.0,0.0,0.0,0.0,165.38,616.5000000000001,980.5700000000004,159.89,541.0600000000001,905.1300000000005,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,-6.439999999999998,0.0,0.0,0.0,0.0,0.0,0.0,15.349999999999994,34.129999999999995,60.76000000000005,15.349999999999994,34.129999999999995,60.76000000000005,0.0,0.0,0.0,0.0,0.0,0.0,107.45000000000002,107.45000000000002,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


#### Ground-truth column names after casting (suffixed with `CAST_COLUMN_TAG`)

In [35]:
# cast_column appended CAST_COLUMN_TAG to every column, so the label and
# ground-truth names must carry the same suffix.
ground_truth_col_renamed = [name + CAST_COLUMN_TAG for name in ground_truth_col_names]
purchase_label_renamed, return_label_renamed = (
    purchase_label_name + CAST_COLUMN_TAG,
    return_label_name + CAST_COLUMN_TAG,
)

#### Input pipeline

In [37]:
# https://dwgeek.com/python-pyspark-iterator-how-to-create-and-use.html/
def get_dict_iter(df):
  """Stream the rows of the Spark DataFrame `df` to the driver as plain dicts.

  toLocalIterator() fetches one partition at a time, so the driver needs memory
  for the largest partition rather than the whole frame.
  """
  def _row_to_dict(row):
    return row.asDict()

  dict_rdd = df.rdd.map(_row_to_dict)
  return dict_rdd.toLocalIterator()

In [38]:
# test
# Row-dict iterator over the validation split; the batching test below pulls from it.
sample_iter = get_dict_iter(val_df)

In [39]:
def get_df_batch(df_iter, batch_size=100):
  """Pull up to `batch_size` items from `df_iter` into a list.

  Returns fewer items (possibly none) once the iterator is exhausted. Errors
  raised by the iterator itself (e.g. a failed Spark job behind
  toLocalIterator) propagate instead of silently producing a short batch.
  """
  batch = []
  for _ in range(batch_size):
    try:
      batch.append(next(df_iter))
    except StopIteration:
      # Source exhausted: stop polling it.
      break
  return batch

def get_tf_dataset_batch(df_iter, batch_size):
  """Turn the next `batch_size` rows of `df_iter` into a batched tf.data.Dataset.

  The label is net spend: purchase label minus return label. Both label columns
  are popped from the pandas frame first, so they are not among the features.
  NOTE(review): other ground-truth columns (e.g. net_txn_amt_) stay in the
  feature dict. Only the feature layer built by get_features decides which
  keys the model actually reads.

  Returns:
    (dataset, features_pd): a dataset of (feature dict, net label) pairs batched
    by `batch_size`, and the pandas frame of the remaining columns.
  """
  rows = get_df_batch(df_iter, batch_size)
  features_pd = pd.DataFrame.from_dict(rows)

  purchase_labels = features_pd.pop(purchase_label_renamed)
  return_labels = features_pd.pop(return_label_renamed)

  dataset = tf.data.Dataset.from_tensor_slices(
      (dict(features_pd), purchase_labels - return_labels))
  return dataset.batch(batch_size), features_pd

In [40]:
# test: turn one validation row into a dataset batch and inspect its features/label.
sample_ds, sample_pd = get_tf_dataset_batch(sample_iter, batch_size=1)

for feature_batch, label_batch in sample_ds.take(1):
  print('EVERY_FEATURE=', list(feature_batch.keys()))
  print('FEATURE_BATCH repeat_spend_12mo_sls_=', feature_batch['repeat_spend_12mo_sls_'])
  print('TARGET_BATCH=', label_batch)

In [41]:
# test: a DenseFeatures layer built from a single numeric feature column must
# produce output for the sample batch.
# `feature_column` is also used by the feature-building helpers in the next cell.
from tensorflow import feature_column

# Each dataset element is a (features dict, label) pair; keep the features dict.
sample_batch = next(iter(sample_ds))[0]
col = feature_column.numeric_column('repeat_spend_12mo_sls_')
sample_layer = tf.keras.layers.DenseFeatures(col)
layer_numpy = sample_layer(sample_batch).numpy()

print("layer_numpy=", layer_numpy)
assert len(layer_numpy)>0, "BATCHING_FAILED"

In [42]:
from tensorflow.keras.layers import Layer

def avoid_snooping_truth(xy_pd, truth_cols=None):
  """Return a copy of `xy_pd` without any ground-truth columns.

  This prevents target leakage. For example, in the displayed rows
  `net_txn_amt_` equals purchase minus return spend, i.e. the label itself.

  Args:
    xy_pd: pandas DataFrame holding features and possibly ground-truth columns.
    truth_cols: names to remove; defaults to the notebook's
      `ground_truth_col_renamed`. Names not present are ignored, since the
      labels may already have been popped.

  Returns:
    A new DataFrame; `xy_pd` itself is not modified.
  """
  if truth_cols is None:
    truth_cols = ground_truth_col_renamed
  present = [name for name in truth_cols if name in xy_pd.columns]
  # columns= drops *columns*; a bare drop(name) would target a row label.
  return xy_pd.drop(columns=present)


def get_numeric_features(x_pd):
  """Build one tf numeric feature column for each float64 column of `x_pd`."""
  float_columns = x_pd.select_dtypes(include=['float64']).columns
  return [feature_column.numeric_column(header) for header in float_columns]

  
def get_features(xy_pd):
  """Build the model's input layer from a pandas batch.

  The batch first goes through avoid_snooping_truth, which is meant to strip
  ground-truth columns. Every remaining float64 column then becomes a numeric
  feature column.
  Categorical indicator features (get_indicator_features) are currently not used.
  """
  numeric_columns = get_numeric_features(avoid_snooping_truth(xy_pd))
  return tf.keras.layers.DenseFeatures(list(numeric_columns))


def get_indicator_features(x_pd):
  """Build one-hot indicator feature columns for the object (string) columns of `x_pd`.

  Each column's vocabulary is the set of distinct non-null values seen in
  `x_pd`. Values outside that vocabulary at inference time fall back to TF's
  default handling (default_value=-1, no OOV buckets). Presumably this yields an
  all-zero one-hot vector; confirm before relying on it.

  Returns:
    A list of tf indicator feature columns (possibly empty).
  """
  feature_columns = []
  for col in x_pd.select_dtypes(include=['object']).columns:
    # The vocabulary must be the column's own values, not the list of column names.
    vocabulary = x_pd[col].dropna().unique().tolist()
    if not vocabulary:
      # TF rejects an empty vocabulary; an all-null column carries no signal.
      continue
    category_feature = feature_column.categorical_column_with_vocabulary_list(col, vocabulary)
    feature_columns.append(feature_column.indicator_column(category_feature))
  return feature_columns


In [43]:
# test
sample_layer = get_features(sample_pd)

sample_layer

In [44]:
def neural_net_model(input_layer, 
                     hidden_layer_neuron_count=256, num_hidden_layers=4):
    """Stack `input_layer`, `num_hidden_layers` ReLU Dense layers and one linear output unit.

    Returns:
      An uncompiled tf.keras.Sequential regression model.
    """
    hidden_stack = [
        tf.keras.layers.Dense(hidden_layer_neuron_count, kernel_initializer='normal', activation='relu')
        for _ in range(num_hidden_layers)
    ]
    regression_head = tf.keras.layers.Dense(1, kernel_initializer='normal', activation='linear')
    return tf.keras.Sequential([input_layer] + hidden_stack + [regression_head])

In [45]:
from collections import defaultdict

from keras.callbacks import History 
from keras.callbacks import EarlyStopping


def train_model(model_key,
                loss_function,
                train_df, test_df,
                learning_rate,
                layer_neurons_nonlinear, num_layers_nonlinear,
                whole_epochs=1, epoch_per_batch=1000,  # relies on EarlyStopping
                num_batches=10):
  """Train one feed-forward regressor on ``train_df`` in streamed chunks.

  Each whole epoch builds a fresh iterator over ``train_df`` and fits the model
  on ``num_batches`` consecutive chunks. Each chunk is fitted for up to
  ``epoch_per_batch`` epochs and stops early once ``val_loss`` on the held-out
  test chunk has not improved for 5 epochs.

  Args:
    model_key: label used only in progress logging.
    loss_function: Keras loss; it is also tracked as the first metric.
    train_df, test_df: data consumed through ``get_dict_iter`` /
      ``get_tf_dataset_batch``. Presumably Spark DataFrames, since
      ``train_df.count()`` is used as a row count.
    learning_rate: SGD learning rate.
    layer_neurons_nonlinear: units per hidden layer.
    num_layers_nonlinear: number of hidden layers.
    whole_epochs: number of full passes over ``train_df``.
    epoch_per_batch: maximum epochs fitted on each chunk.
    num_batches: chunks per pass. Each chunk requests
      ``train_df.count() // num_batches`` rows (at least 1).

  Returns:
    ``(model, histories)``, where ``histories`` holds one Keras ``History`` per
    fitted chunk, in training order.

  Raises:
    ValueError: if ``num_batches`` is less than 1.
  """
  if num_batches < 1:
    raise ValueError("num_batches must be >= 1, got {}".format(num_batches))
  batch_size = max(1, train_df.count() // num_batches)

  # One validation chunk (same size as a training chunk) is reused for every fit.
  test_ds, test_pd = get_tf_dataset_batch(get_dict_iter(test_df), batch_size)
  feature_layer = get_features(test_pd)

  training_model = neural_net_model(feature_layer, layer_neurons_nonlinear, num_layers_nonlinear)
  training_model.compile(loss=loss_function,
                         optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
                         metrics=[loss_function, 'mean_absolute_error'])

  training_history = []
  for remaining_epochs in range(whole_epochs, 0, -1):
    # Fresh iterator per pass so every whole epoch walks the training set again.
    train_iter = get_dict_iter(train_df)  # .orderBy(rand()

    # Separate counter: num_batches must not be consumed by the first pass,
    # otherwise every later pass would train on zero chunks.
    for remaining_batches in range(num_batches, 0, -1):
      print("\n\n*******************\n",
            "MODEL_KEY=", model_key,
            "REMAINING_EPOCHS=", remaining_epochs,
            "REMAINING_BATCHES=", remaining_batches,
            "\n*******************\n")

      train_ds, _ = get_tf_dataset_batch(train_iter, batch_size)

      # Use tf.keras callbacks with a tf.keras model; fit() returns the History itself.
      history = training_model.fit(
          train_ds,
          validation_data=test_ds,
          shuffle=True,
          epochs=epoch_per_batch,
          callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='auto', patience=5)])  # min_delta=1
      training_history.append(history)

  return training_model, training_history

In [46]:
def train_models(hyper_params):
  """Grid-search ``hyper_params``, training one model per combination.

  ``hyper_params`` must provide iterables under 'loss_function',
  'layer_neurons_nonlinear', 'num_layers_nonlinear', 'learning_rate',
  'whole_dataset_epochs' and 'epoch_per_batch'. Models are trained on the
  notebook-level ``train_df`` / ``test_df``. If training raises for a
  combination, the failure is printed and the rest of the grid still runs.

  Returns:
    ``(param_model, param_history)``: dicts keyed by a string that encodes the
    combination. They map to the trained model and to its list of per-chunk
    Keras histories.
  """
  import itertools

  param_model = {}
  param_history = {}

  # Same iteration order as nested loops: loss outermost, epoch_per_batch innermost.
  grid = itertools.product(hyper_params['loss_function'],
                           hyper_params['layer_neurons_nonlinear'],
                           hyper_params['num_layers_nonlinear'],
                           hyper_params['learning_rate'],
                           hyper_params['whole_dataset_epochs'],
                           hyper_params['epoch_per_batch'])

  for loss, height, depth, rate, whole_epochs, epoch_per_batch in grid:
    # Build the key outside `try` so the failure log always names this combination.
    # getattr keys callable losses by name; loss-name strings are used as-is.
    loss_name = getattr(loss, '__name__', str(loss))
    model_key = '-depth={}-height={}-whole_epochs={}-epoch_per_batch={}-loss={}-rate={}'.format(
        depth, height, whole_epochs, epoch_per_batch, loss_name, rate)
    try:
      model, history = train_model(model_key,
                                   loss,
                                   train_df, test_df,
                                   rate,
                                   layer_neurons_nonlinear=height, num_layers_nonlinear=depth,
                                   whole_epochs=whole_epochs, epoch_per_batch=epoch_per_batch)
      param_model[model_key] = model
      param_history[model_key] = history
    except Exception as e:
      print("TRAINING_EXCEPTION", model_key, e)

  return param_model, param_history

In [47]:
def choose_model(param_model, param_history, comparison_batch_size=250000):
  """Pick the candidate model with the lowest loss on one batch of ``test_df``.

  Args:
    param_model: mapping of model key -> compiled Keras model.
    param_history: not used; kept so existing calls keep working.
    comparison_batch_size: batch size requested from ``get_tf_dataset_batch``
      for the single evaluation batch.

  Returns:
    ``(chosen_key, chosen_model, chosen_test_loss, chosen_test_error)``. Error
    is ``model.evaluate(...)[1]``, the first compiled metric. All four are
    None when no model could be evaluated. Models whose evaluation raises are
    printed and skipped.
  """
  import math

  test_ds, _ = get_tf_dataset_batch(get_dict_iter(test_df), batch_size=comparison_batch_size)

  def _rank(loss):
    # NaN never compares less than anything, so rank it as +inf. Otherwise a
    # NaN-loss model evaluated first would block every later candidate.
    return math.inf if math.isnan(loss) else loss

  chosen_key = None
  chosen_model = None
  chosen_test_loss = None
  chosen_test_error = None
  for model_key, model in param_model.items():
    try:
      test_evaluation = model.evaluate(test_ds)
      test_loss = test_evaluation[0]
      test_objective_error = test_evaluation[1]
      print("MODEL_EVALUATED=", model_key, "TEST_LOSS=", test_loss, "TEST_ERROR=", test_objective_error)
      if chosen_model is None or _rank(test_loss) < _rank(chosen_test_loss):
        chosen_key = model_key
        chosen_model = model
        chosen_test_loss = test_loss
        chosen_test_error = test_objective_error
        print("MODEL_SELECTED=", model_key)
    except Exception as e:
      print(model_key, "FAILED_SELECTION", e)
  return chosen_key, chosen_model, chosen_test_loss, chosen_test_error

In [48]:
def measure_model_performance(chosen_key, chosen_model, measurement_batch_size=50000):
  """Report the chosen model's error on one batch of the held-out ``val_df``.

  Best-effort: failures are printed rather than raised, so the notebook keeps running.

  Args:
    chosen_key: model identifier, used in the printed report.
    chosen_model: compiled Keras model, or None when selection found nothing.
    measurement_batch_size: batch size requested from ``get_tf_dataset_batch``.

  Returns:
    ``(loss, error)`` from ``chosen_model.evaluate``, where index 0 is the
    compiled loss and index 1 the first compiled metric. None when there is no
    model or evaluation fails.
  """
  if chosen_model is None:
    print("NO_MODEL_TO_MEASURE", chosen_key)
    return None
  try:
    val_ds, _ = get_tf_dataset_batch(get_dict_iter(val_df), batch_size=measurement_batch_size)
    val_evaluation = chosen_model.evaluate(val_ds)
    loss, error = val_evaluation[0], val_evaluation[1]
  except Exception as e:
    print("MEASUREMENT_FAILED", chosen_key, e)
    return None
  print("CHOSEN_MODEL=", chosen_key, "VALIDATION_ERROR=", error)
  return loss, error

### Run Model Training and Selection

#### Loss = mean_absolute_percentage_error

In [51]:
# Grid-search space: train_models trains one model per combination of these values.
# Other candidate losses: 'mean_squared_error', mean_absolute_error, mean_squared_logarithmic_error,
# cosine_similarity, huber_loss, log_cosh
hyper_params = {
  'loss_function': ['mean_absolute_percentage_error'],
  'num_layers_nonlinear': [1, 2, 3],
  'layer_neurons_nonlinear': [6, 12, 18],
  # NOTE(review): an SGD step of 1e-12 barely moves the weights -- confirm this is intended.
  'learning_rate': [10**-12],
  'whole_dataset_epochs': [1],  # 10
  'epoch_per_batch': [5],
}


In [52]:
# Train one model per hyper-parameter combination; combinations that raise are printed and skipped.
param_model, param_history = train_models(hyper_params)

In [53]:
# Select the trained model with the lowest loss on a batch of the test set.
chosen_key, chosen_model, chosen_test_loss, chosen_test_error = choose_model(param_model, param_history)

In [54]:
# Report the selected model's error on a batch of the held-out validation set (val_df).
measure_model_performance(chosen_key, chosen_model)

### Plot

#### Plot Helper Classes

In [57]:
from collections import defaultdict


class model_history():
  """Accumulates Keras per-epoch metric series across many ``fit`` calls.

  Series are routed into three buckets, each a ``defaultdict(list)`` keyed by
  metric name:
    * ``loss_history``             -- 'loss', 'val_loss'
    * ``error_history``            -- 'mean_absolute_error', 'val_mean_absolute_error'
    * ``percentage_error_history`` -- 'mean_absolute_percentage_error',
                                      'val_mean_absolute_percentage_error'
  Any other metric name is ignored.
  """

  def __init__(self):
    self.loss_history = defaultdict(list)
    self.error_history = defaultdict(list)
    self.percentage_error_history = defaultdict(list)

  def _bucket_for(self, key):
    """Return the bucket that tracks metric ``key``, or None if it is untracked."""
    if key in ("loss", "val_loss"):
      return self.loss_history
    if key in ("mean_absolute_error", "val_mean_absolute_error"):
      return self.error_history
    if key in ("mean_absolute_percentage_error", "val_mean_absolute_percentage_error"):
      return self.percentage_error_history
    return None

  def add(self, key, values):
    """Append ``values`` to the series stored under ``key`` in its bucket."""
    bucket = self._bucket_for(key)
    if bucket is not None:
      bucket[key].extend(values)
      
# test: both 'loss' series land in loss_history, concatenated in call order
# (expected display: {'loss': [1, 2, 3, 3, 2, 1]}).
h = model_history()
h.add('loss', [1,2,3])
h.add('loss', [3,2,1])
h.loss_history

In [58]:
# Flatten each model's per-chunk Keras History objects into one model_history,
# so its learning curves span every chunk/epoch of training in order.
model_train_history = {}

for model_key, history_arr in param_history.items():
  try:
    for history_batch in history_arr:
      for metric_name, metric_history in history_batch.history.items():
        if model_key not in model_train_history:
          model_train_history[model_key] = model_history()
        model_train_history[model_key].add(metric_name, metric_history)
  except Exception as e:
    # Keep collecting the other models' curves, but say which one failed.
    print("HISTORY_COLLECTION_FAILED", model_key, e)

In [59]:
def plot_learning_curves(model_key, train_history):
  """Plot one line per metric series in ``train_history`` for a single model.

  Args:
    model_key: model identifier; printed and used as the figure title.
    train_history: mapping of metric name -> list of per-epoch values, such as
      one of the ``model_history`` buckets. Each series is drawn against its
      position in the list.

  Best-effort: a failure (e.g. an empty history or series of unequal length)
  is printed instead of raised, so loops over many models keep going.
  """
  try:
    print("model_key=", model_key)
    ax = pd.DataFrame(train_history).plot(figsize=(8, 5), title=model_key)
    ax.set_xlabel("epoch (cumulative)")
    ax.grid(True)
  except Exception as e:
    print("PLOT_FAILED", model_key, e)

#### Plot the Selected Model

In [61]:
# Loss / val_loss curves of the selected model.
plot_learning_curves(chosen_key, model_train_history[chosen_key].loss_history)

In [62]:
# Mean-absolute-percentage-error curves of the selected model.
plot_learning_curves(chosen_key, model_train_history[chosen_key].percentage_error_history)

In [63]:
  plot_learning_curves(chosen_key, model_train_history[chosen_key].error_history)

#### Plot All Models

In [65]:
# Loss / val_loss curves for every trained model.
for key, tracked in model_train_history.items():
  plot_learning_curves(key, tracked.loss_history)

In [66]:
# Mean-absolute-error curves for every trained model.
for key, per_model in model_train_history.items():
  plot_learning_curves(key, per_model.error_history)

In [67]:
# Mean-absolute-percentage-error curves for every trained model.
for key, curves in model_train_history.items():
  plot_learning_curves(key, curves.percentage_error_history)