In [0]:
%sh
vmstat 3

procs -----------memory---------- ---swap-- -----io---- -system-- ------cpu-----
 r  b   swpd   free   buff  cache   si   so    bi    bo   in   cs us sy id wa st
 6  0      0 5401268      0 1520444    0    1   598 22550 2558 4042 34  8 31 23  3
 3  0      0 5370912      0 1515128    0   25  1236  6661 5146 10769 85  6  5  4  0
 4  0      0 5333764      0 1517144    0   13  1797   437 4992 10956 89  7  3  1  0
 0  0      0 5301644      0 1509332    0   21   349   888 5041 10829 48  5 46  0  0
 4  0      0 5276952      0 1532440    0   13     0    75 6065 11392  9  3 89  0  0
 0  0      0 5276420      0 1532440   11    0    11  1024 6162 11251  6  1 92  0  0
 0  0      0 5262156      0 1546144    0   13     0  1139 5340 10497 28  2 70  0  0
 0  0      0 5260752      0 1546144    0    0     0   253 5912 11132 13  0 86  0  0
 4  0      0 5259116      0 1546144    0    0     0   316 5949 11212  8  1 91  0  0
 0  0      0 5257460      0 1546144    0    0     0     0 5648 11083 15  2 82  0  0

In [0]:
orders_sdf = spark.read.csv('/FileStore/tables/orders.csv', header=True, inferSchema=True)
trains_sdf = spark.read.csv('/FileStore/tables/order_products_train.csv', header=True, inferSchema=True)
products_sdf = spark.read.csv('/FileStore/tables/products.csv', header=True, inferSchema=True)
aisles_sdf = spark.read.csv('/FileStore/tables/aisles.csv', header=True, inferSchema=True)
depts_sdf = spark.read.csv('/FileStore/tables/departments.csv', header=True, inferSchema=True)

In [0]:
%fs 
cp /FileStore/tables/order_products_prior.zip file:/home/order_products_prior.zip 

In [0]:
import pandas as pd

priors_pdf = pd.read_csv('/home/order_products_prior.zip', compression='zip', header=0, sep=',', quotechar='"')
priors_sdf = spark.createDataFrame(priors_pdf)
del priors_pdf # 메모리 절약을 위해 pandas dataframe삭제

In [0]:
orders_sdf.createOrReplaceTempView("orders")
priors_sdf.createOrReplaceTempView("priors")
trains_sdf.createOrReplaceTempView("trains")
products_sdf.createOrReplaceTempView("products")
aisles_sdf.createOrReplaceTempView("aisles")
depts_sdf.createOrReplaceTempView("depts")

In [0]:
spark.catalog.listTables()

Out[5]: [Table(name='aisles', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='depts', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='orders', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='priors', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='products', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='trains', database=None, description=None, tableType='TEMPORARY', isTemporary=True)]

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/order_priors_prods

In [0]:
%sql
drop table if exists order_priors_prods;

-- priors와 orders를 조인
create table order_priors_prods
as
select a.order_id, a.product_id, a.add_to_cart_order, a.reordered
  , b.user_id, b.eval_set, b.order_number, b.order_dow, b.order_hour_of_day
  , b.days_since_prior_order
from priors a, orders b
where a.order_id = b.order_id;

num_affected_rows,num_inserted_rows


### 상품 레벨의 분석 속성에 기반한 상품 분석 테이블 생성
* PK 는 상품코드(product_id)이며 이전 EDA에서 분석한 속성들로 상품 분석 테이블 생성.

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/prd_mart

In [0]:
%sql
drop table if exists prd_mart;

create table prd_mart
as
with 
-- with 구문 첫번째 집합. product_id 레벨로 group by 하여 상품별 서로 다른 개별 사용자 비율을 추출한 결과에 상품명과 상품 중분류명 알기 위해 products와 aisles로 조인
order_prods_grp as
(
  select a.product_id 
    -- ## 상품별 재주문 속성
    , sum(case when reordered=1 then 1 else 0 end) as prd_reordered_cnt -- 상품별 재 주문 건수
    , sum(case when reordered=0 then 1 else 0 end) as prd_no_reordered_cnt -- 상품별 재 주문 하지 않은 건수 
    , avg(reordered) prd_avg_reordered -- 상품별 재 주문 비율
    -- ## 상품별 고유 사용자 및 이전 주문이후 걸린 일자 속성. 
    , count(distinct user_id) prd_unq_usr_cnt -- 상품별 고유 사용자 건수
    , count(*)  prd_total_cnt -- 상품별 건수
    , count(distinct user_id)/count(*) as prd_usr_ratio -- 상품별 전체 건수 대비 고유 사용자 비율
    , max(c.aisle_id) aisle_id -- 상품 중분류 코드 
    , nvl(avg(days_since_prior_order), 0) as prd_avg_prior_days -- 평균 이전 주문이후 걸린 일자, null인 경우 0으로 변환. 
    , nvl(min(days_since_prior_order), 0) as prd_min_prior_days -- 최소 이전 주문이후 걸린 일자, null인 경우 0으로 변환. 
    , nvl(max(days_since_prior_order), 0) as prd_max_prior_days -- 최대 이전 주문이후 걸린 일자, null인 경우 0으로 변환. 
    from order_priors_prods a, products b, aisles c
  where a.product_id = b.product_id 
  and b.aisle_id = c.aisle_id
  group by a.product_id
),
-- with 구문 두번째 집합. product_id 레벨로 group by 하여 상품별 서로 다른 개별 사용자 비율을 추출한 결과에 product_name과 중분류명, 대분류명을 알기 위해 aisles와 dept로 조인. 
order_aisles_grp as
(
  select c.aisle_id as aisle_id 
     , count(distinct a.user_id) aisle_distinct_usr_cnt -- 상품 중분류별 고유 사용자 건수
     , count(*)  aisle_total_cnt -- 상품 중분류 건수
     , count(distinct a.user_id)/count(*) as aisle_usr_ratio -- 상품 중분류 건수 대비 고유 사용자 건수 비율
  from order_priors_prods a, products b, aisles c
  where a.product_id = b.product_id 
  and b.aisle_id = c.aisle_id
  group by c.aisle_id
),
-- with 구문 세번째 집합. 상품 중분류 별 개별 사용자 비율과 상품별 개별 사용자 비율 차이 추출. 
order_prd_grp_aisle as
(
  select product_id, prd_reordered_cnt,  prd_no_reordered_cnt, prd_avg_reordered, prd_unq_usr_cnt, prd_total_cnt, prd_usr_ratio
    , prd_avg_prior_days, prd_min_prior_days, prd_max_prior_days-- 상품별 속성들
    , b.aisle_distinct_usr_cnt, b.aisle_total_cnt, b.aisle_usr_ratio -- 상품 중분류별 속성들 
    , a.prd_usr_ratio - b.aisle_usr_ratio as usr_ratio_diff -- 상품 중분류 별 개별 사용자 비율과 상품별 개별 사용자 비율 차이
  from order_prods_grp a, order_aisles_grp b
  where a.aisle_id = b.aisle_id
) 
-- end of with 절
select * from order_prd_grp_aisle

num_affected_rows,num_inserted_rows


In [0]:
%sql
select * from prd_mart limit 10

product_id,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff
47908,0,3,0.0,3,3,1.0,22.33333333333333,7.0,30.0,85357,575881,0.1482198579220359,0.851780142077964
9856,0,3,0.0,3,3,1.0,10.666666666666666,7.0,17.0,85357,575881,0.1482198579220359,0.851780142077964
3832,0,2,0.0,2,2,1.0,21.5,13.0,30.0,85357,575881,0.1482198579220359,0.851780142077964
12120,0,3,0.0,3,3,1.0,6.333333333333333,6.0,7.0,85357,575881,0.1482198579220359,0.851780142077964
10536,3,7,0.3,7,10,0.7,21.625,9.0,30.0,85357,575881,0.1482198579220359,0.551780142077964
33171,0,8,0.0,8,8,1.0,10.571428571428571,3.0,30.0,85357,575881,0.1482198579220359,0.851780142077964
29994,13,4,0.7647058823529411,4,17,0.2352941176470588,7.666666666666667,1.0,21.0,85357,575881,0.1482198579220359,0.0870742597250228
28551,28,7,0.8,7,35,0.2,6.8,0.0,29.0,85357,575881,0.1482198579220359,0.051780142077964
46860,0,5,0.0,5,5,1.0,11.25,3.0,22.0,85357,575881,0.1482198579220359,0.851780142077964
1431,9,14,0.391304347826087,14,23,0.6086956521739131,11.428571428571429,1.0,30.0,85357,575881,0.1482198579220359,0.4604757942518771


In [0]:
%sql
--49676
select count(*) from prd_mart

count(1)
49676


In [0]:
import pyspark.sql.functions as F

prd_mart_sdf = spark.sql("select * from prd_mart")

display(prd_mart_sdf.select([F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in prd_mart_sdf.columns]))

product_id,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff
0,0,0,0,0,0,0,0,0,0,0,0,0,0


### 사용자 레벨의 분석 속성에 기반한 사용자 분석 테이블 생성
* PK 는 사용자아이디(user_id)이며 이전 EDA에서 분석한 속성들로 사용자 분석 테이블 생성.
* 추후에 예측 데이터를 만들기 위해 order_id가 필요. 이를 위해 train과 test용 orders 데이터와 user_id로 조인하여 order_id 추출 필요.
* orders 테이블은 user_id레벨로 m 이지만 eval_set이 train과 test일 경우는 user_id 레벨로 1이므로 조인시 user_mart 테이블 레벨이 변하지 않음.

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/user_mart_01

In [0]:
%sql
drop table if exists user_mart_01;

create table user_mart_01
as
select user_id 
  , count(*) as usr_total_cnt -- 사용자별 주문 건수
  -- 주문 건수 관련 속성 추출. 
  , count(distinct product_id) prd_uq_cnt  -- 사용자별 고유 상품 주문 건수
  , count(distinct order_id) order_uq_cnt -- 사용자별 고유 주문 건수
  , count(*)/count(distinct order_id) as usr_avg_prd_cnt -- 사용자별 1회 주문시 평균 주문 상품 건수
  , count(*)/count(distinct product_id) as usr_avg_uq_prd_cnt -- 사용자별 1회 주문시 평균 고유 주문 상품 건수
  , count(distinct product_id)/count(*) as usr_uq_prd_ratio --사용자별 총 상품 건수 대비 고유 상품 건수 비율
  -- ### reordered 관련 속성 추출. ###
  , sum(reordered) usr_reord_cnt -- 사용자별 reordered된 상품 건수
  , sum(case when reordered = 0 then 1 else 0 end) as usr_no_reord_cnt -- 사용자별 reorder 하지 않은 상품 건수. count(*) - sum(reoredred)와 동일. 
  , avg(reordered) usr_reordered_avg -- 사용자별 reordered 비율
  -- ### days_since_prior_order 관련 속성 추출. ###
  , avg(days_since_prior_order) usr_avg_prior_days
  , max(days_since_prior_order) usr_max_prior_days
  , min(days_since_prior_order) usr_min_prior_days
  -- ### order_dow, order_hour_of_day 관련 속성 추출. ###
  , avg(order_dow) usr_avg_order_dow
  , avg(order_hour_of_day) usr_avg_order_hour_of_day
  -- 사용자별 최대 order_number
  , max(order_number) as usr_max_order_number
from order_priors_prods a group by user_id

num_affected_rows,num_inserted_rows


In [0]:
%sql
select count(*) from user_mart_01

count(1)
206209


In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/user_mart

In [0]:
%sql
drop table if exists user_mart;

-- orders는 eval_set이 train/test일 경우 한개의 user_id가 한개의 order_id를 가짐. 때문에 train/test인 경우 조인키값 user_id로 1레벨이 됨.
-- order_priors_prods에 있는 모든 user_id는 orders의 모든 user_id와 동일. orders는 user_id별로 여러건의 order가 있고, 이들중 마지막 order를 train또는 test로 할당하기 때문
-- 따라서 user_mart_01과 eval_set이 train과 test인 orders를 user_id로 조인하면 1:1 조인이 되고 user_mart_01의 집합 레벨의 변화가 없음. outer join을 하지 않아도 됨. 
create table user_mart
as
select a.*, b.order_id, b.eval_set, b.days_since_prior_order
from user_mart_01 a, orders b
where a.user_id = b.user_id
and b.eval_set in ('train', 'test')   

num_affected_rows,num_inserted_rows


In [0]:
%sql
select * from user_mart limit 10

user_id,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,order_id,eval_set,days_since_prior_order
1,59,18,10,5.9,3.2777777777777777,0.3050847457627119,41,18,0.6949152542372882,20.25925925925926,30.0,0.0,2.6440677966101696,10.542372881355933,10,1187899,train,14.0
2,195,102,14,13.928571428571429,1.911764705882353,0.5230769230769231,93,102,0.4769230769230769,15.967032967032967,30.0,3.0,2.005128205128205,10.44102564102564,14,1492625,train,30.0
3,88,33,12,7.333333333333333,2.6666666666666665,0.375,55,33,0.625,11.487179487179487,21.0,7.0,1.0113636363636365,16.352272727272727,12,2774568,test,11.0
4,18,17,5,3.6,1.0588235294117647,0.9444444444444444,1,17,0.0555555555555555,15.357142857142858,21.0,0.0,4.722222222222222,13.11111111111111,5,329954,test,30.0
5,37,23,4,9.25,1.608695652173913,0.6216216216216216,14,23,0.3783783783783784,14.5,19.0,10.0,1.6216216216216215,15.72972972972973,4,2196797,train,6.0
6,14,12,3,4.666666666666667,1.1666666666666667,0.8571428571428571,2,12,0.1428571428571428,7.8,12.0,6.0,3.857142857142857,17.0,3,1528013,test,22.0
7,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,525192,train,6.0
8,49,36,3,16.333333333333332,1.3611111111111112,0.7346938775510204,13,36,0.2653061224489796,30.0,30.0,30.0,4.204081632653061,2.4489795918367347,3,880375,train,10.0
9,76,58,3,25.33333333333333,1.3103448275862069,0.7631578947368421,18,58,0.2368421052631578,24.26086956521739,30.0,6.0,2.6973684210526314,14.263157894736842,3,1094988,train,30.0
10,143,94,5,28.6,1.5212765957446808,0.6573426573426573,49,94,0.3426573426573426,20.746376811594203,30.0,12.0,4.013986013986014,16.902097902097903,5,1822501,train,30.0


In [0]:
%sql
--206209
select count(*) from user_mart

count(1)
206209


In [0]:
%sql
select count(*)
from orders b
where b.eval_set in ('train', 'test')  

count(1)
206209


In [0]:
%sql
select * from orders where user_id = 1

order_id,user_id,eval_set,order_number,order_dow,order_hour_of_day,days_since_prior_order
2539329,1,prior,1,2,8,
2398795,1,prior,2,3,7,15.0
473747,1,prior,3,3,12,21.0
2254736,1,prior,4,4,7,29.0
431534,1,prior,5,4,15,28.0
3367565,1,prior,6,2,7,19.0
550135,1,prior,7,1,9,20.0
3108588,1,prior,8,1,14,14.0
2295261,1,prior,9,1,16,0.0
2550362,1,prior,10,4,8,30.0


In [0]:
%sql
select * from orders where user_id = 3

order_id,user_id,eval_set,order_number,order_dow,order_hour_of_day,days_since_prior_order
1374495,3,prior,1,1,14,
444309,3,prior,2,3,19,9.0
3002854,3,prior,3,3,16,21.0
2037211,3,prior,4,2,18,20.0
2710558,3,prior,5,0,17,12.0
1972919,3,prior,6,0,16,7.0
1839752,3,prior,7,0,15,7.0
3225766,3,prior,8,0,17,7.0
3160850,3,prior,9,0,16,7.0
676467,3,prior,10,3,16,17.0


### 사용자 + 상품 레벨의 분석 속성에 기반한 사용자+상품 분석 테이블 생성
* PK 는 사용자아이디(user_id)+상품코드(product_id)이며 이전 EDA에서 분석한 속성들로 사용자+상품 분석 테이블 생성 테이블 생성.
* 앞에서 만든 prd_mart, user_mart를 사용자+상품 분석 테이블과 조인하여 상품관련 속성, 사용자 관련 속성을 결합함.

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/up_mart

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/up_mart_01

In [0]:
%sql
select * from order_priors_prods where user_id = 1

order_id,product_id,add_to_cart_order,reordered,user_id,eval_set,order_number,order_dow,order_hour_of_day,days_since_prior_order
473747,196,1,1,1,prior,3,3,12,21.0
473747,12427,2,1,1,prior,3,3,12,21.0
473747,10258,3,1,1,prior,3,3,12,21.0
473747,25133,4,0,1,prior,3,3,12,21.0
473747,30450,5,0,1,prior,3,3,12,21.0
3108588,12427,1,1,1,prior,8,1,14,14.0
3108588,196,2,1,1,prior,8,1,14,14.0
3108588,10258,3,1,1,prior,8,1,14,14.0
3108588,25133,4,1,1,prior,8,1,14,14.0
3108588,46149,5,0,1,prior,8,1,14,14.0


In [0]:
%sql
drop table if exists up_mart;
drop table if exists up_mart_01;

create table up_mart
as
with 
-- 사용자+상품 레벨로 group by 하여 속성 추출. 
up_grp as
(
SELECT user_id, product_id
    , count(*) up_cnt  -- 사용자의 개별 상품별 주문 건수
    , sum(reordered) up_reord_cnt -- 사용자의 개별 상품별 reorder 건수
    , sum(case when reordered=0 then 1 else 0 end) up_no_reord_cnt
    , avg(reordered) up_reoredered_avg -- 사용자의 개별 상품 주문별 reorder비율 
    , max(order_number) up_max_ord_num -- 사용자+상품레벨에서 가장 큰 order_number. order_number는 사용자 별로 주문을 수행한 일련번호를 순차적으로 가짐. 
    , min(order_number) up_min_ord_num -- 사용자+상품레벨에서 가장 작은 order_number
    , avg(add_to_cart_order) up_avg_cart --사용자 상품레벨에서 보통 cart에 몇번째로 담는가?
    , avg(days_since_prior_order) as up_avg_prior_days
    , max(days_since_prior_order) as up_max_prior_days
    , min(days_since_prior_order) as up_min_prior_days
    , avg(order_dow) as up_avg_ord_dow
    , avg(order_hour_of_day) as up_avg_ord_hour
FROM order_priors_prods GROUP BY user_id, product_id
)
-- end of with 절 
-- 사용자 레벨로 group by 한 user_mart 테이블과 조인하여 사용자 레벨 속성과 사용자+상품 레벨 속성의 비율을 추출. 
select a.* 
  , a.up_cnt/b.usr_total_cnt as up_usr_ratio -- 사용자별 전체 주문 건수 대비 사용자+상품 주문 건수 비율
  , a.up_reord_cnt/b.usr_reord_cnt as up_usr_reord_ratio -- 사용자별 전체 재주문 건수 대비 사용자+상품 재주문 건수 비율
  , b.usr_reord_cnt
  , b.usr_max_order_number - a.up_max_ord_num as up_usr_ord_num_diff -- 사용자의 가장 최근 주문(가장 큰 주문번호)에서 현 상품 주문번호가 어느정도 이후에 있는지
from up_grp a, user_mart b
where a.user_id = b.user_id

num_affected_rows,num_inserted_rows


In [0]:
%sql
select * from user_mart where usr_reord_cnt is null

user_id,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,order_id,eval_set,days_since_prior_order


In [0]:
%sql
select * from up_mart where usr_reord_cnt is null;

user_id,product_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,usr_reord_cnt,up_usr_ord_num_diff


In [0]:
%sql
select * from up_mart where up_usr_reord_ratio is null

user_id,product_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,usr_reord_cnt,up_usr_ord_num_diff
162723,46865,1,0,1,0.0,3,3,3.0,3.0,3.0,3.0,4.0,8.0,0.03125,,0,0
85639,5550,1,0,1,0.0,2,2,2.0,6.0,6.0,6.0,1.0,9.0,0.0277777777777777,,0,2
146277,13948,1,0,1,0.0,3,3,1.0,30.0,30.0,30.0,6.0,10.0,0.2,,0,1
133015,43814,1,0,1,0.0,1,1,6.0,,,,3.0,9.0,0.0588235294117647,,0,2
182613,45832,1,0,1,0.0,3,3,2.0,1.0,1.0,1.0,4.0,11.0,0.0833333333333333,,0,0
151874,33793,1,0,1,0.0,1,1,8.0,,,,1.0,14.0,0.0454545454545454,,0,4
185468,47766,1,0,1,0.0,2,2,7.0,10.0,10.0,10.0,5.0,15.0,0.0555555555555555,,0,1
68323,27801,1,0,1,0.0,2,2,8.0,30.0,30.0,30.0,2.0,10.0,0.027027027027027,,0,2
166616,2566,1,0,1,0.0,1,1,4.0,,,,6.0,22.0,0.1111111111111111,,0,2
184240,16823,1,0,1,0.0,2,2,1.0,30.0,30.0,30.0,1.0,14.0,0.0416666666666666,,0,2


In [0]:
%sql
select * from
(
SELECT user_id, product_id
    , count(*) up_cnt  -- 사용자의 개별 상품별 주문 건수
    , sum(reordered) up_reord_cnt -- 사용자의 개별 상품별 reorder 건수
    , sum(case when reordered=0 then 1 else 0 end) up_no_reord_cnt
    , avg(reordered) up_reoredered_avg -- 사용자의 개별 상품 주문별 reorder비율 
    , max(order_number) up_max_ord_num -- 사용자+상품레벨에서 가장 큰 order_number. order_number는 사용자 별로 주문을 수행한 일련번호를 순차적으로 가짐. 
    , min(order_number) up_min_ord_num -- 사용자+상품레벨에서 가장 작은 order_number
    , avg(add_to_cart_order) up_avg_cart --사용자 상품레벨에서 보통 cart에 몇번째로 담는가?
    , avg(days_since_prior_order) as up_avg_prior_days
    , max(days_since_prior_order) as up_max_prior_days
    , min(days_since_prior_order) as up_min_prior_days
    , avg(order_dow) as up_avg_ord_dow
    , avg(order_hour_of_day) as up_avg_ord_hour
FROM order_priors_prods GROUP BY user_id, product_id
) where up_reord_cnt is null

user_id,product_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour


In [0]:
%sql
select * from up_mart where up_usr_reord_ratio is null

user_id,product_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,usr_reord_cnt,up_usr_ord_num_diff
162723,46865,1,0,1,0.0,3,3,3.0,3.0,3.0,3.0,4.0,8.0,0.03125,,0,0
85639,5550,1,0,1,0.0,2,2,2.0,6.0,6.0,6.0,1.0,9.0,0.0277777777777777,,0,2
146277,13948,1,0,1,0.0,3,3,1.0,30.0,30.0,30.0,6.0,10.0,0.2,,0,1
133015,43814,1,0,1,0.0,1,1,6.0,,,,3.0,9.0,0.0588235294117647,,0,2
182613,45832,1,0,1,0.0,3,3,2.0,1.0,1.0,1.0,4.0,11.0,0.0833333333333333,,0,0
151874,33793,1,0,1,0.0,1,1,8.0,,,,1.0,14.0,0.0454545454545454,,0,4
185468,47766,1,0,1,0.0,2,2,7.0,10.0,10.0,10.0,5.0,15.0,0.0555555555555555,,0,1
68323,27801,1,0,1,0.0,2,2,8.0,30.0,30.0,30.0,2.0,10.0,0.027027027027027,,0,2
166616,2566,1,0,1,0.0,1,1,4.0,,,,6.0,22.0,0.1111111111111111,,0,2
184240,16823,1,0,1,0.0,2,2,1.0,30.0,30.0,30.0,1.0,14.0,0.0416666666666666,,0,2


In [0]:
%sql
-- 위 create table up_mart가 오래 걸릴 경우(10분 이상) 아래와 같이 sql을 분할하여 수행. 
-- group by user_id, product_id 로 속성 추출하여 up_mart_01 생성. 
create table up_mart_01
as
SELECT user_id, product_id
    , count(*) usr_prd_cnt  -- 사용자의 개별 상품별 주문 건수
    , sum(reordered) up_reord_cnt -- 사용자의 개별 상품별 reorder 건수
    , sum(case when reordered=0 then 1 else 0 end) up_no_reord_cnt
    , avg(reordered) up_reoredered_avg -- 사용자의 개별 상품 주문별 reorder비율 
    , max(order_number) up_max_ord_num -- 사용자+상품레벨에서 가장 큰 order_number. order_number는 사용자 별로 주문을 수행한 일련번호를 순차적으로 가짐. 
    , min(order_number) up_min_ord_num -- 사용자+상품레벨에서 가장 큰 order_number
    , avg(add_to_cart_order) up_avg_cart --사용자 상품레벨에서 보통 cart에 몇번째로 담는가?
    , avg(days_since_prior_order) as up_avg_prior_days
    , max(days_since_prior_order) as up_max_prior_days
    , min(days_since_prior_order) as up_min_prior_days
    , avg(order_dow) as up_avg_ord_dow
    , avg(order_hour_of_day) as up_avg_ord_hour
FROM order_priors_prods GROUP BY user_id, product_id; 

-- up_mart_01과 user_mart 를 user_id로 조인하여 추가 속성 생성. 
create table up_mart
as
select a.* 
  , a.up_cnt/b.usr_total_cnt as up_usr_ratio -- 사용자별 전체 주문 건수 대비 사용자+상품 주문 건수 비율
  , a.up_reord_cnt/b.usr_reord_cnt as up_usr_reord_ratio -- 사용자별 전체 재주문 건수 대비 사용자+상품 재주문 건수 비율
  , b.usr_max_order_number - a.up_max_ord_num as up_usr_ord_num_diff -- 사용자의 가장 최근 주문(가장 큰 주문번호)에서 현 사용자+상품 주문번호가 어느정도 이후에 있는지
from up_mart_01 a, user_mart b
where a.user_id = b.user_id

In [0]:
%sql
--13307953
select count(*) from up_mart

In [0]:
%sql
-- up_mart에서 user_mart로, user_id로 join이 안되거나 prd_mart로, product_id로 join이 안되는 경우 추출.  
select count(*)
from up_mart a 
left outer join user_mart b
on a.user_id = b.user_id
left outer join prd_mart c
on a.product_id = c.product_id
where (b.user_id is null or c.product_id is null)

In [0]:
%sql
select * from aisles where aisle_id='Blunted'
/* 
select * from products a where product_id = 6816
select * from aisles where aisle_id='Blunted' 
*/

In [0]:
%sql
-- 현재까지 만들어진 테이블의 건수 조사 
select 'user_mart count' as gubun, count(*) from user_mart
union all
select 'prd_mart count' as gubun, count(*) from prd_mart
union all
select 'up_mart count' as gubun, count(*) from up_mart

#### 현재까지 만든 prd_mart, user_mart, up_mart를 결합하여 data_mart 생성. 
* 생성된 data_mart는 up_mart를 기준으로 prd_mart, user_mart를 조인하여 상품 분석속성, 사용자 분석속성을 결합.

In [0]:
%sql
describe up_mart

In [0]:
print(spark.sql("select * from up_mart").columns)
print(spark.sql("select * from user_mart").columns)
print(spark.sql("select * from prd_mart").columns)

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/data_mart

In [0]:
%sql
drop table if exists data_mart;

-- up_mart에 user_mart를 user_id로 조인, prd_mart는 product_id로 조인하여 개별 xxx_mart테이블의 속성들을 취합하여 data_mart 테이블 생성. 약 4분정도 걸림. 
create table data_mart
as
select 
  -- up_mart 컬럼 
  a.user_id, a.product_id, b.order_id -- 테스트 데이터 예측 데이터 제출을 위해서 order_id가 필요함. 
  , up_cnt, up_reord_cnt, up_no_reord_cnt, up_reoredered_avg, up_max_ord_num, up_min_ord_num, up_avg_cart, up_avg_prior_days, up_max_prior_days
  , up_min_prior_days, up_avg_ord_dow, up_avg_ord_hour, up_usr_ratio, up_usr_reord_ratio, up_usr_ord_num_diff
  -- user_mart 컬럼, eval_set에 train과 test용 데이터(사용자)구분
  , usr_total_cnt, prd_uq_cnt, order_uq_cnt, usr_avg_prd_cnt, usr_avg_uq_prd_cnt, usr_uq_prd_ratio, usr_reord_cnt, usr_no_reord_cnt, usr_reordered_avg, usr_avg_prior_days
  , usr_max_prior_days, usr_min_prior_days, usr_avg_order_dow, usr_avg_order_hour_of_day, usr_max_order_number, eval_set, days_since_prior_order
  -- prd_mart 컬럼
  , prd_reordered_cnt, prd_no_reordered_cnt, prd_avg_reordered, prd_unq_usr_cnt, prd_total_cnt, prd_usr_ratio, prd_avg_prior_days, prd_min_prior_days, prd_max_prior_days
  , aisle_distinct_usr_cnt, aisle_total_cnt, aisle_usr_ratio, usr_ratio_diff
from up_mart a, user_mart b, prd_mart c
where a.user_id = b.user_id and a.product_id = c.product_id

In [0]:
%sql
-- 현재까지 생성된 테이블의 건수 조사. data_mart는 up_mart와 동일 건수 - 3 
select 'data_mart count' as gubun, count(*) from data_mart
union all 
select 'user_mart count' as gubun, count(*) from user_mart
union all
select 'prd_mart count' as gubun, count(*) from prd_mart
union all
select 'up_mart count' as gubun, count(*) from up_mart

In [0]:
%sql
select * from data_mart limit 10

### 학습과 테스트용 데이터 세트 생성. 
* order_products_train.csv(trains 테이블)는  train용으로 reordered label 값이 주어져 있음.
* trains 테이블의 pk는 order_id + product_id 이지만 실제로는 1개의 user_id에 1개의 order_id만 할당되므로 user_id + product_id로 unique함. 
* trains 테이블과 orders 테이블을 조인하여 user_id를 가져오는 order_trains_prods 테이블 생성. 
* order_trains_prods 테이블을 기준으로 data_mart에서 생성한 속성을 붙이려고 두개의 테이블을 user_id + product_id로 조인(order_trains_prods 레프트 아우터 조인)하면 많은 데이터가 조인되지 않음.  
* 조인되지 않을 경우에 data_mart에서 생성한 속성을 사용할 수 없음. 
* data_mart를 기준으로 order_trains_prods를 조인(data_mart 레프트 아우터 조인)하여 label값인 reordered를 설정하고 조인되지 않는 경우 reordered를 0으로 설정.

In [0]:
%sql
select * from trains limit 10

In [0]:
%sql
--1384617
select count(*) from trains

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/order_trains_prods

In [0]:
%sql
drop table if exists order_trains_prods;
-- order_products_train 데이터에(trains 테이블)에 user_id를 얻기 위해서 orders 테이블과 조인
-- 해당 테이블은 kaggle에서 train 용으로 제공었지만, 많은 속성(feature)로 만들어진 data_mart 테이블에 비해 적은 속성을 가지고 있음. 
create table order_trains_prods
as
select a.order_id, a.product_id, a.reordered
  , b.user_id
from trains a, orders b
where a.order_id = b.order_id

In [0]:
%sql
select count(*) from order_trains_prods

In [0]:
%sql
select * from order_trains_prods limit 10

In [0]:
%sql
-- user_id + product_id 로 중복되는 건수가 없음. 즉 user_id + product_id로 unique
select user_id, product_id, count(*) from order_trains_prods group by user_id, product_id having count(*) > 1

In [0]:
%sql
-- 555793 건이 user_id + product_id 레벨로 order_trains_prods과 data_mart와 조인되지 않음. 
select count(*) 
from
order_trains_prods a
left outer join data_mart b
on a.user_id = b.user_id and a.product_id = b.product_id
where b.product_id is null

In [0]:
%sql
-- 개별 user_id 레벨, 또는 개별 product_id 레벨로 조인이 되지 않는 건수는 거의 없음. 즉 개별 user는 동일하지만 user+상품은 prior와 train에 별도로 존재하는 경우가 많다는 의미
with
data_user_grp as
(
  select user_id from data_mart group by user_id
),
data_product_grp as
(
  select product_id from data_mart group by product_id
)
select 'only_user_id_count' as gubun, count(*) from order_trains_prods a left outer join data_user_grp b on a.user_id = b.user_id 
where b.user_id is null
union all
select 'only_product_id_count' as gubun, count(*) from order_trains_prods a left outer join data_product_grp b on a.product_id = b.product_id 
where b.product_id is null

In [0]:
%sql
-- data_mart와 user_id + product_id로 조인되지 않는 order_trains_prods의 reordered는 모두 0 임. 
select a.reordered, count(*) 
from
order_trains_prods a
left outer join data_mart b
on a.user_id = b.user_id and a.product_id = b.product_id
where b.product_id is null
group by a.reordered

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/train_data

In [0]:
print(spark.sql("select * from data_mart").columns)

In [0]:
%sql
-- 학습용 feature와 label 데이터 세트 생성. 
-- order_trains_prods를 기준으로 data_mart와 outer 조인하면 많은 데이터가 조인되지 않음. 이 경우 해당 데이터는 data_mart의 속성들을 사용할 수 없음. 
-- data_mart를 기준으로 order_trains_prods를 outer 조인하면 역시 많은 데이터가 조인되지 않음. data_mart의 속성은 여전히 사용할 수 있음. 
-- order_trains_prods를 기준으로 학습 데이터를 만들지 않고 data_mart를 기준으로 학습 데이터를 생성. 
-- order_trains_prods의 eval_set가 'train' 인 경우 user_id 레벨로 학습 데이터이므로 이를 이용하여 학습 데이터를 생성. 
-- data_mart와 order_trains_prods가 조인이 되는 경우 order_trains_prods의 reorder값을 이용하고, 조인이 되지 않는 경우는 0으로 (추후)변경
drop table if exists train_data;

create table train_data
as
select 
-- user_id, product_id, order_id -- 학습용 feature 데이터를 만들기에 user_id, product_id, order_id 와 같은 id 속성은 제외
  up_cnt, up_reord_cnt, up_no_reord_cnt, up_reoredered_avg, up_max_ord_num, up_min_ord_num, up_avg_cart, up_avg_prior_days, up_max_prior_days, up_min_prior_days
, up_avg_ord_dow, up_avg_ord_hour, up_usr_ratio, up_usr_reord_ratio, up_usr_ord_num_diff, usr_total_cnt, prd_uq_cnt, order_uq_cnt, usr_avg_prd_cnt, usr_avg_uq_prd_cnt
, usr_uq_prd_ratio, usr_reord_cnt, usr_no_reord_cnt, usr_reordered_avg, usr_avg_prior_days, usr_max_prior_days, usr_min_prior_days, usr_avg_order_dow
, usr_avg_order_hour_of_day, usr_max_order_number
--, eval_set -- eval_set 제외
, days_since_prior_order, prd_reordered_cnt, prd_no_reordered_cnt, prd_avg_reordered, prd_unq_usr_cnt
, prd_total_cnt, prd_usr_ratio, prd_avg_prior_days, prd_min_prior_days, prd_max_prior_days, aisle_distinct_usr_cnt
, aisle_total_cnt, aisle_usr_ratio, usr_ratio_diff
, b.reordered -- label 값. order_train_prods와 조인되지 않는 경우 label을 0으로 변경. nvl(b.reordered, 0) 적용
from data_mart a left outer join order_trains_prods b
on a.user_id = b.user_id and a.product_id = b.product_id
where a.eval_set = 'train'

In [0]:
%sql
-- data_mart에서 order_trains_prods와 user_id+product_id로 조인되지 않는 건은 7645837 건. 추후에 DataFrame에서 이들 데이터의 reordered를 모두 0으로 변경. 
select 'train_data' as gubun, count(*) from train_data
union all
select 'reordered null' as gubun, count(*) from train_data where reordered is null

In [0]:
%fs
rm -r dbfs:/user/hive/warehouse/test_data

In [0]:
%sql
-- 테스트용 데이터 세트 생성. data_mart에서 eval_set이 test 인것만 추출하여 생성. reordered 필요 없음. 
drop table if exists test_data;

create table test_data
as
select 
user_id, product_id, order_id -- 테스트 데이터와 학습 데이터와 마찬가지로 id 속성값은 필요가 없지만, 추후에 kaggle 테스트 성능 결과 제출을 위해 order_id, product_id가 필요. 추후 제거. 
, up_cnt, up_reord_cnt, up_no_reord_cnt, up_reoredered_avg, up_max_ord_num, up_min_ord_num, up_avg_cart, up_avg_prior_days, up_max_prior_days, up_min_prior_days
, up_avg_ord_dow, up_avg_ord_hour, up_usr_ratio, up_usr_reord_ratio, up_usr_ord_num_diff, usr_total_cnt, prd_uq_cnt, order_uq_cnt, usr_avg_prd_cnt, usr_avg_uq_prd_cnt
, usr_uq_prd_ratio, usr_reord_cnt, usr_no_reord_cnt, usr_reordered_avg, usr_avg_prior_days, usr_max_prior_days, usr_min_prior_days, usr_avg_order_dow
, usr_avg_order_hour_of_day, usr_max_order_number
--, eval_set -- eval_set 제외
, days_since_prior_order, prd_reordered_cnt, prd_no_reordered_cnt, prd_avg_reordered, prd_unq_usr_cnt
, prd_total_cnt, prd_usr_ratio, prd_avg_prior_days, prd_min_prior_days, prd_max_prior_days, aisle_distinct_usr_cnt
, aisle_total_cnt, aisle_usr_ratio, usr_ratio_diff
--, b.reordered -- label 제외
from data_mart a where a.eval_set = 'test' -- data_mart에서 eval_set이 test 인것만 추출하여 생성. 

In [0]:
%sql
select count(*) from test_data

### 학습 데이터 전처리 및 모델 학습, 예측 평가 수행
* 학습데이터 Null값은 모두 0으로 처리

In [0]:
%fs
ls dbfs:/user/hive/warehouse/train_data

path,name,size
dbfs:/user/hive/warehouse/test_data/_delta_log/,_delta_log/,0
dbfs:/user/hive/warehouse/test_data/part-00000-85a8f62e-2c0c-49ef-840d-ff57720f38de-c000.snappy.parquet,part-00000-85a8f62e-2c0c-49ef-840d-ff57720f38de-c000.snappy.parquet,14785657
dbfs:/user/hive/warehouse/test_data/part-00001-97335b52-a307-4e4f-8e40-fda9e747583b-c000.snappy.parquet,part-00001-97335b52-a307-4e4f-8e40-fda9e747583b-c000.snappy.parquet,14804783
dbfs:/user/hive/warehouse/test_data/part-00002-e9753e7a-37e6-47e5-8cc8-c37d19f67f0b-c000.snappy.parquet,part-00002-e9753e7a-37e6-47e5-8cc8-c37d19f67f0b-c000.snappy.parquet,13745579
dbfs:/user/hive/warehouse/test_data/part-00003-28dfbeaa-1cf0-4f01-b513-8ee5bf60a3b5-c000.snappy.parquet,part-00003-28dfbeaa-1cf0-4f01-b513-8ee5bf60a3b5-c000.snappy.parquet,13961602
dbfs:/user/hive/warehouse/test_data/part-00004-9971b533-f49c-42be-a4b4-1168a0030079-c000.snappy.parquet,part-00004-9971b533-f49c-42be-a4b4-1168a0030079-c000.snappy.parquet,14144402
dbfs:/user/hive/warehouse/test_data/part-00005-68cbcdf4-c903-485d-b455-e6c6447cb2dc-c000.snappy.parquet,part-00005-68cbcdf4-c903-485d-b455-e6c6447cb2dc-c000.snappy.parquet,14926364
dbfs:/user/hive/warehouse/test_data/part-00006-07b66837-559d-41b1-8d5e-f4341410cdcf-c000.snappy.parquet,part-00006-07b66837-559d-41b1-8d5e-f4341410cdcf-c000.snappy.parquet,13978967
dbfs:/user/hive/warehouse/test_data/part-00007-df2ed6a4-a2da-4c58-8b4f-689b07034fba-c000.snappy.parquet,part-00007-df2ed6a4-a2da-4c58-8b4f-689b07034fba-c000.snappy.parquet,14603704
dbfs:/user/hive/warehouse/test_data/part-00008-594e4db7-1736-4fa2-b0ec-e367d66b3e8c-c000.snappy.parquet,part-00008-594e4db7-1736-4fa2-b0ec-e367d66b3e8c-c000.snappy.parquet,13670148


In [0]:
%sql

drop table if exists train_data;

CREATE TABLE train_data
USING parquet
OPTIONS (
       path "/user/hive/warehouse/train_data/" );

drop table if exists test_data;

CREATE TABLE test_data
USING parquet
OPTIONS (
       path "/user/hive/warehouse/test_data/" ); 

In [0]:
spark.catalog.listTables()

Out[55]: [Table(name='order_priors_prods', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='prd_mart', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='test_data', database='default', description=None, tableType='EXTERNAL', isTemporary=False),
 Table(name='train_data', database='default', description=None, tableType='EXTERNAL', isTemporary=False),
 Table(name='up_mart', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='user_mart', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='user_mart_01', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='aisles', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='depts', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='orders', database=None, description=None, tab

In [0]:
spark.sql("set spark.databricks.delta.formatCheck.enabled=false")

Out[56]: DataFrame[key: string, value: string]

In [0]:
# train_data와 test_data 테이블을 DataFrame으로 변환. 
train_sdf = spark.sql("select * from train_data")
test_sdf = spark.sql("select * from test_data")
print('train_sdf type:', type(train_sdf))
print('test_sdf type:', type(test_sdf))

train_sdf type: <class 'pyspark.sql.dataframe.DataFrame'>
test_sdf type: <class 'pyspark.sql.dataframe.DataFrame'>


In [0]:
train_sdf.printSchema()

root
 |-- up_cnt: long (nullable = true)
 |-- up_reord_cnt: long (nullable = true)
 |-- up_no_reord_cnt: long (nullable = true)
 |-- up_reoredered_avg: double (nullable = true)
 |-- up_max_ord_num: integer (nullable = true)
 |-- up_min_ord_num: integer (nullable = true)
 |-- up_avg_cart: double (nullable = true)
 |-- up_avg_prior_days: double (nullable = true)
 |-- up_max_prior_days: double (nullable = true)
 |-- up_min_prior_days: double (nullable = true)
 |-- up_avg_ord_dow: double (nullable = true)
 |-- up_avg_ord_hour: double (nullable = true)
 |-- up_usr_ratio: double (nullable = true)
 |-- up_usr_reord_ratio: double (nullable = true)
 |-- up_usr_ord_num_diff: integer (nullable = true)
 |-- usr_total_cnt: long (nullable = true)
 |-- prd_uq_cnt: long (nullable = true)
 |-- order_uq_cnt: long (nullable = true)
 |-- usr_avg_prd_cnt: double (nullable = true)
 |-- usr_avg_uq_prd_cnt: double (nullable = true)
 |-- usr_uq_prd_ratio: double (nullable = true)
 |-- usr_reord_cnt: long (null

In [0]:
# 컬럼별로 Null 인 경우만 count하는 select 로직.
import pyspark.sql.functions as F
# up_avg_prior_days, up_max_prior_days, up_min_prior_days이 각각 552218건, up_max_prior_days 552218건, up_usr_reord_ratio가 30912건,  
# prd_avg_prior_days가 29건, prd_max_prior_days가 29 건이 null임. up_usr_reord_ratio는 사용자별 재주문도 0건, 사용자 상품별 재주문도 0건이어 무한대가 Null로 처리됨. 
# reoredered는 7654837 건이 null임. 
display(train_sdf.select([F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in train_sdf.columns]))

up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,reordered
0,0,0,0,0,0,0,552218,552218,552218,0,0,0,30912,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7645837


In [0]:
train_sdf = train_sdf.fillna(0)

In [0]:
# feature vectorization 적용할 column명 추출. label 컬럼인 reordered는 제외
vector_columns = [column_name for column_name, column_type in train_sdf.dtypes if column_name != 'reordered']
print(vector_columns)

['up_cnt', 'up_reord_cnt', 'up_no_reord_cnt', 'up_reoredered_avg', 'up_max_ord_num', 'up_min_ord_num', 'up_avg_cart', 'up_avg_prior_days', 'up_max_prior_days', 'up_min_prior_days', 'up_avg_ord_dow', 'up_avg_ord_hour', 'up_usr_ratio', 'up_usr_reord_ratio', 'up_usr_ord_num_diff', 'usr_total_cnt', 'prd_uq_cnt', 'order_uq_cnt', 'usr_avg_prd_cnt', 'usr_avg_uq_prd_cnt', 'usr_uq_prd_ratio', 'usr_reord_cnt', 'usr_no_reord_cnt', 'usr_reordered_avg', 'usr_avg_prior_days', 'usr_max_prior_days', 'usr_min_prior_days', 'usr_avg_order_dow', 'usr_avg_order_hour_of_day', 'usr_max_order_number', 'days_since_prior_order', 'prd_reordered_cnt', 'prd_no_reordered_cnt', 'prd_avg_reordered', 'prd_unq_usr_cnt', 'prd_total_cnt', 'prd_usr_ratio', 'prd_avg_prior_days', 'prd_min_prior_days', 'prd_max_prior_days', 'aisle_distinct_usr_cnt', 'aisle_total_cnt', 'aisle_usr_ratio', 'usr_ratio_diff']


In [0]:
# feature vector화 적용 
from pyspark.ml.feature import VectorAssembler

vector_assembler = VectorAssembler(inputCols=vector_columns, outputCol='features')
train_sdf_vectorized = vector_assembler.transform(train_sdf)

display(train_sdf_vectorized.limit(10))

up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,reordered,features
8,7,1,0.875,20,1,6.75,13.428571428571429,30.0,4.0,2.125,12.375,0.0388349514563106,0.0507246376811594,0,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,6.0,11884,7461,0.6143189454639442,7461,19345,0.3856810545360558,12.253245299910477,0.0,30.0,85357,575881,0.1482198579220359,0.2374611966140198,0,"Map(vectorType -> dense, length -> 44, values -> List(8.0, 7.0, 1.0, 0.875, 20.0, 1.0, 6.75, 13.428571428571429, 30.0, 4.0, 2.125, 12.375, 0.038834951456310676, 0.050724637681159424, 0.0, 206.0, 68.0, 20.0, 10.3, 3.0294117647058822, 0.3300970873786408, 138.0, 68.0, 0.6699029126213593, 13.54639175257732, 30.0, 2.0, 1.7281553398058251, 13.631067961165048, 20.0, 6.0, 11884.0, 7461.0, 0.6143189454639442, 7461.0, 19345.0, 0.38568105453605583, 12.253245299910475, 0.0, 30.0, 85357.0, 575881.0, 0.14821985792203599, 0.23746119661401985))"
8,7,1,0.875,20,2,9.625,15.375,30.0,5.0,1.625,13.5,0.0388349514563106,0.0507246376811594,0,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,6.0,7556,4567,0.6232780664851935,4567,12123,0.3767219335148065,11.31117903930131,0.0,30.0,25372,70887,0.357921762805592,0.0188001707092145,1,"Map(vectorType -> dense, length -> 44, values -> List(8.0, 7.0, 1.0, 0.875, 20.0, 2.0, 9.625, 15.375, 30.0, 5.0, 1.625, 13.5, 0.038834951456310676, 0.050724637681159424, 0.0, 206.0, 68.0, 20.0, 10.3, 3.0294117647058822, 0.3300970873786408, 138.0, 68.0, 0.6699029126213593, 13.54639175257732, 30.0, 2.0, 1.7281553398058251, 13.631067961165048, 20.0, 6.0, 7556.0, 4567.0, 0.6232780664851935, 4567.0, 12123.0, 0.37672193351480654, 11.31117903930131, 0.0, 30.0, 25372.0, 70887.0, 0.357921762805592, 0.018800170709214525))"
1,0,1,0.0,18,18,7.0,7.0,7.0,7.0,0.0,9.0,0.0048543689320388,0.0,2,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,6.0,281,1220,0.1872085276482345,1220,1501,0.8127914723517655,9.720815752461322,0.0,30.0,86080,326692,0.263489770181088,0.5493017021706775,1,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 18.0, 18.0, 7.0, 7.0, 7.0, 7.0, 0.0, 9.0, 0.0048543689320388345, 0.0, 2.0, 206.0, 68.0, 20.0, 10.3, 3.0294117647058822, 0.3300970873786408, 138.0, 68.0, 0.6699029126213593, 13.54639175257732, 30.0, 2.0, 1.7281553398058251, 13.631067961165048, 20.0, 6.0, 281.0, 1220.0, 0.1872085276482345, 1220.0, 1501.0, 0.8127914723517655, 9.720815752461322, 0.0, 30.0, 86080.0, 326692.0, 0.263489770181088, 0.5493017021706775))"
1,0,1,0.0,15,15,9.0,2.0,2.0,2.0,3.0,16.0,0.0048543689320388,0.0,5,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,6.0,14607,6824,0.6815827539545518,6824,21431,0.3184172460454482,9.726868370888118,0.0,30.0,177141,3642188,0.048635874919142,0.2697813711263062,0,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 15.0, 15.0, 9.0, 2.0, 2.0, 2.0, 3.0, 16.0, 0.0048543689320388345, 0.0, 5.0, 206.0, 68.0, 20.0, 10.3, 3.0294117647058822, 0.3300970873786408, 138.0, 68.0, 0.6699029126213593, 13.54639175257732, 30.0, 2.0, 1.7281553398058251, 13.631067961165048, 20.0, 6.0, 14607.0, 6824.0, 0.6815827539545518, 6824.0, 21431.0, 0.3184172460454482, 9.726868370888118, 0.0, 30.0, 177141.0, 3642188.0, 0.04863587491914201, 0.2697813711263062))"
1,0,1,0.0,16,16,5.0,11.0,11.0,11.0,0.0,13.0,0.0048543689320388,0.0,4,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,6.0,50472,26424,0.6563670411985019,26424,76896,0.3436329588014981,11.644956456147568,0.0,30.0,159418,1765313,0.0903057984618025,0.2533271603396956,0,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 16.0, 16.0, 5.0, 11.0, 11.0, 11.0, 0.0, 13.0, 0.0048543689320388345, 0.0, 4.0, 206.0, 68.0, 20.0, 10.3, 3.0294117647058822, 0.3300970873786408, 138.0, 68.0, 0.6699029126213593, 13.54639175257732, 30.0, 2.0, 1.7281553398058251, 13.631067961165048, 20.0, 6.0, 50472.0, 26424.0, 0.6563670411985019, 26424.0, 76896.0, 0.34363295880149813, 11.644956456147568, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.2533271603396956))"
5,4,1,0.8,11,1,3.0,12.25,30.0,4.0,2.0,13.8,0.0242718446601941,0.0289855072463768,9,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,6.0,89,53,0.6267605633802817,53,142,0.3732394366197183,14.02290076335878,1.0,30.0,124393,1452343,0.0856498774738474,0.2875895591458708,0,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 11.0, 1.0, 3.0, 12.25, 30.0, 4.0, 2.0, 13.8, 0.024271844660194174, 0.028985507246376812, 9.0, 206.0, 68.0, 20.0, 10.3, 3.0294117647058822, 0.3300970873786408, 138.0, 68.0, 0.6699029126213593, 13.54639175257732, 30.0, 2.0, 1.7281553398058251, 13.631067961165048, 20.0, 6.0, 89.0, 53.0, 0.6267605633802817, 53.0, 142.0, 0.3732394366197183, 14.022900763358779, 1.0, 30.0, 124393.0, 1452343.0, 0.08564987747384743, 0.28758955914587087))"
1,0,1,0.0,3,3,2.0,30.0,30.0,30.0,0.0,18.0,0.0048543689320388,0.0,17,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,6.0,2518,2387,0.5133537206931702,2387,4905,0.4866462793068297,11.30755939524838,0.0,30.0,69429,270314,0.2568457423588863,0.2298005369479434,0,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 2.0, 30.0, 30.0, 30.0, 0.0, 18.0, 0.0048543689320388345, 0.0, 17.0, 206.0, 68.0, 20.0, 10.3, 3.0294117647058822, 0.3300970873786408, 138.0, 68.0, 0.6699029126213593, 13.54639175257732, 30.0, 2.0, 1.7281553398058251, 13.631067961165048, 20.0, 6.0, 2518.0, 2387.0, 0.5133537206931702, 2387.0, 4905.0, 0.48664627930682974, 11.30755939524838, 0.0, 30.0, 69429.0, 270314.0, 0.25684574235888635, 0.2298005369479434))"
7,6,1,0.8571428571428571,17,2,9.285714285714286,13.428571428571429,30.0,3.0,1.5714285714285714,14.571428571428571,0.0339805825242718,0.0434782608695652,3,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,6.0,4890,1952,0.7147033031277404,1952,6842,0.2852966968722595,10.8748031496063,0.0,30.0,41929,175757,0.2385623332214364,0.0467343636508231,0,"Map(vectorType -> dense, length -> 44, values -> List(7.0, 6.0, 1.0, 0.8571428571428571, 17.0, 2.0, 9.285714285714286, 13.428571428571429, 30.0, 3.0, 1.5714285714285714, 14.571428571428571, 0.03398058252427184, 0.043478260869565216, 3.0, 206.0, 68.0, 20.0, 10.3, 3.0294117647058822, 0.3300970873786408, 138.0, 68.0, 0.6699029126213593, 13.54639175257732, 30.0, 2.0, 1.7281553398058251, 13.631067961165048, 20.0, 6.0, 4890.0, 1952.0, 0.7147033031277404, 1952.0, 6842.0, 0.28529669687225956, 10.874803149606299, 0.0, 30.0, 41929.0, 175757.0, 0.2385623332214364, 0.046734363650823146))"
1,0,1,0.0,18,18,2.0,7.0,7.0,7.0,0.0,9.0,0.0048543689320388,0.0,2,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,6.0,4938,4461,0.5253750398978615,4461,9399,0.4746249601021385,10.403641360357266,0.0,30.0,159213,3418021,0.0465804627882625,0.4280444973138759,0,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 18.0, 18.0, 2.0, 7.0, 7.0, 7.0, 0.0, 9.0, 0.0048543689320388345, 0.0, 2.0, 206.0, 68.0, 20.0, 10.3, 3.0294117647058822, 0.3300970873786408, 138.0, 68.0, 0.6699029126213593, 13.54639175257732, 30.0, 2.0, 1.7281553398058251, 13.631067961165048, 20.0, 6.0, 4938.0, 4461.0, 0.5253750398978615, 4461.0, 9399.0, 0.4746249601021385, 10.403641360357266, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.42804449731387595))"
3,2,1,0.6666666666666666,12,3,2.0,13.666666666666666,30.0,4.0,2.0,13.0,0.0145631067961165,0.0144927536231884,8,206,68,20,10.3,3.0294117647058822,0.3300970873786408,138,68,0.6699029126213593,13.54639175257732,30.0,2.0,1.7281553398058251,13.631067961165048,20,6.0,4427,2249,0.6631216297183943,2249,6676,0.3368783702816058,10.780603380192703,0.0,30.0,58749,390299,0.150523060525392,0.1863553097562137,0,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 12.0, 3.0, 2.0, 13.666666666666666, 30.0, 4.0, 2.0, 13.0, 0.014563106796116505, 0.014492753623188406, 8.0, 206.0, 68.0, 20.0, 10.3, 3.0294117647058822, 0.3300970873786408, 138.0, 68.0, 0.6699029126213593, 13.54639175257732, 30.0, 2.0, 1.7281553398058251, 13.631067961165048, 20.0, 6.0, 4427.0, 2249.0, 0.6631216297183943, 2249.0, 6676.0, 0.3368783702816058, 10.780603380192703, 0.0, 30.0, 58749.0, 390299.0, 0.15052306052539208, 0.1863553097562137))"


In [0]:
# 학습 데이터로 학습하여 Estimator Model 생성.. 7~8분 정도 걸림. 
from pyspark.ml.classification import RandomForestClassifier

rf_estimator = RandomForestClassifier(featuresCol='features', labelCol='reordered')
rf_model = rf_estimator.fit(train_sdf_vectorized)

In [0]:
print(type(rf_model))

<class 'pyspark.ml.classification.RandomForestClassificationModel'>


### 테스트 데이터 전처리 및 예측

In [0]:
test_sdf = spark.sql("select * from test_data")

In [0]:
test_sdf.printSchema()

root
 |-- user_id: integer (nullable = true)
 |-- product_id: long (nullable = true)
 |-- order_id: integer (nullable = true)
 |-- up_cnt: long (nullable = true)
 |-- up_reord_cnt: long (nullable = true)
 |-- up_no_reord_cnt: long (nullable = true)
 |-- up_reoredered_avg: double (nullable = true)
 |-- up_max_ord_num: integer (nullable = true)
 |-- up_min_ord_num: integer (nullable = true)
 |-- up_avg_cart: double (nullable = true)
 |-- up_avg_prior_days: double (nullable = true)
 |-- up_max_prior_days: double (nullable = true)
 |-- up_min_prior_days: double (nullable = true)
 |-- up_avg_ord_dow: double (nullable = true)
 |-- up_avg_ord_hour: double (nullable = true)
 |-- up_usr_ratio: double (nullable = true)
 |-- up_usr_reord_ratio: double (nullable = true)
 |-- up_usr_ord_num_diff: integer (nullable = true)
 |-- usr_total_cnt: long (nullable = true)
 |-- prd_uq_cnt: long (nullable = true)
 |-- order_uq_cnt: long (nullable = true)
 |-- usr_avg_prd_cnt: double (nullable = true)
 |-- us

In [0]:
test_sdf_id = test_sdf.select('user_id', 'product_id', 'order_id')
test_sdf = test_sdf.drop('user_id', 'product_id', 'order_id')
display(test_sdf.limit(10))

up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff
3,2,1,0.6666666666666666,4,2,6.0,14.333333333333334,21.0,7.0,4.0,10.666666666666666,0.0638297872340425,0.0952380952380952,2,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,2178,1792,0.5486146095717884,1792,3970,0.4513853904282116,12.340677499311484,0.0,30.0,73840,305655,0.2415795586527294,0.2098058317754821
3,2,1,0.6666666666666666,6,3,2.6666666666666665,22.33333333333333,30.0,7.0,3.333333333333333,8.0,0.0638297872340425,0.0952380952380952,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,23609,19400,0.5489316189634728,19400,43009,0.4510683810365272,12.56283090774228,0.0,30.0,159213,3418021,0.0465804627882625,0.4044879182482647
1,0,1,0.0,1,1,6.0,,,,6.0,12.0,0.0212765957446808,0.0,5,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,3207,2671,0.5455937393671316,2671,5878,0.4544062606328683,13.046151039766508,0.0,30.0,60265,242996,0.2480081976658052,0.2063980629670631
5,4,1,0.8,6,1,3.0,18.25,30.0,7.0,4.4,8.8,0.1063829787234042,0.1904761904761904,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,7917,4524,0.6363636363636364,4524,12441,0.3636363636363636,13.058844194624978,0.0,30.0,92240,452134,0.2040103155259281,0.1596260481104355
1,0,1,0.0,6,6,9.0,30.0,30.0,30.0,4.0,0.0,0.0212765957446808,0.0,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,777,1156,0.4019658561821003,1156,1933,0.5980341438178997,12.62608695652174,0.0,30.0,60265,242996,0.2480081976658052,0.3500259461520945
1,0,1,0.0,5,5,6.0,30.0,30.0,30.0,0.0,12.0,0.0212765957446808,0.0,1,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,23593,15595,0.6020465448606717,15595,39188,0.3979534551393283,11.654457253427084,0.0,30.0,159213,3418021,0.0465804627882625,0.3513729923510658
2,1,1,0.5,6,5,6.5,30.0,30.0,30.0,2.0,6.0,0.0425531914893617,0.0476190476190476,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,13806,9582,0.590302719343253,9582,23388,0.409697280656747,10.543885937429051,0.0,30.0,159213,3418021,0.0465804627882625,0.3631168178684845
1,0,1,0.0,5,5,5.0,30.0,30.0,30.0,0.0,12.0,0.0212765957446808,0.0,1,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,13936,9193,0.6025336158069955,9193,23129,0.3974663841930044,11.050960696677878,0.0,30.0,159418,1765313,0.0903057984618025,0.3071605857312019
1,0,1,0.0,1,1,5.0,,,,6.0,12.0,0.0212765957446808,0.0,5,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,2667,5321,0.3338758137205809,5321,7988,0.6661241862794192,12.752408477842003,0.0,30.0,80108,222049,0.3607672180464672,0.3053569682329519
2,1,1,0.5,6,5,6.5,30.0,30.0,30.0,2.0,6.0,0.0425531914893617,0.0476190476190476,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,33857,21627,0.6102119529954582,21627,55484,0.3897880470045418,10.99423382693818,0.0,30.0,159213,3418021,0.0465804627882625,0.3432075842162793


In [0]:
test_sdf = test_sdf.fillna(0)
test_sdf_vectorized = vector_assembler.transform(test_sdf)

In [0]:
predictions = rf_model.transform(test_sdf_vectorized)
display(predictions)

up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction
3,2,1,0.6666666666666666,4,2,6.0,14.333333333333334,21.0,7.0,4.0,10.666666666666666,0.0638297872340425,0.0952380952380952,2,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,2178,1792,0.5486146095717884,1792,3970,0.4513853904282116,12.340677499311484,0.0,30.0,73840,305655,0.2415795586527294,0.2098058317754821,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 4.0, 2.0, 6.0, 14.333333333333334, 21.0, 7.0, 4.0, 10.666666666666666, 0.06382978723404255, 0.09523809523809523, 2.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 2178.0, 1792.0, 0.5486146095717884, 1792.0, 3970.0, 0.4513853904282116, 12.340677499311484, 0.0, 30.0, 73840.0, 305655.0, 0.2415795586527294, 0.20980583177548218))","Map(vectorType -> dense, length -> 2, values -> List(16.430875460363215, 3.569124539636784))","Map(vectorType -> dense, length -> 2, values -> List(0.8215437730181607, 0.17845622698183922))",0.0
3,2,1,0.6666666666666666,6,3,2.6666666666666665,22.33333333333333,30.0,7.0,3.333333333333333,8.0,0.0638297872340425,0.0952380952380952,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,23609,19400,0.5489316189634728,19400,43009,0.4510683810365272,12.56283090774228,0.0,30.0,159213,3418021,0.0465804627882625,0.4044879182482647,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 6.0, 3.0, 2.6666666666666665, 22.333333333333332, 30.0, 7.0, 3.3333333333333335, 8.0, 0.06382978723404255, 0.09523809523809523, 0.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 23609.0, 19400.0, 0.5489316189634728, 19400.0, 43009.0, 0.45106838103652724, 12.562830907742281, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.4044879182482647))","Map(vectorType -> dense, length -> 2, values -> List(14.543791425706058, 5.456208574293942))","Map(vectorType -> dense, length -> 2, values -> List(0.7271895712853029, 0.27281042871469713))",0.0
1,0,1,0.0,1,1,6.0,0.0,0.0,0.0,6.0,12.0,0.0212765957446808,0.0,5,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,3207,2671,0.5455937393671316,2671,5878,0.4544062606328683,13.046151039766508,0.0,30.0,60265,242996,0.2480081976658052,0.2063980629670631,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 6.0, 0.0, 0.0, 0.0, 6.0, 12.0, 0.02127659574468085, 0.0, 5.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 3207.0, 2671.0, 0.5455937393671316, 2671.0, 5878.0, 0.4544062606328683, 13.046151039766508, 0.0, 30.0, 60265.0, 242996.0, 0.2480081976658052, 0.20639806296706312))","Map(vectorType -> dense, length -> 2, values -> List(18.745288157231194, 1.2547118427688049))","Map(vectorType -> dense, length -> 2, values -> List(0.9372644078615597, 0.06273559213844024))",0.0
5,4,1,0.8,6,1,3.0,18.25,30.0,7.0,4.4,8.8,0.1063829787234042,0.1904761904761904,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,7917,4524,0.6363636363636364,4524,12441,0.3636363636363636,13.058844194624978,0.0,30.0,92240,452134,0.2040103155259281,0.1596260481104355,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 6.0, 1.0, 3.0, 18.25, 30.0, 7.0, 4.4, 8.8, 0.10638297872340426, 0.19047619047619047, 0.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 7917.0, 4524.0, 0.6363636363636364, 4524.0, 12441.0, 0.36363636363636365, 13.058844194624978, 0.0, 30.0, 92240.0, 452134.0, 0.20401031552592816, 0.1596260481104355))","Map(vectorType -> dense, length -> 2, values -> List(13.938783775019786, 6.061216224980218))","Map(vectorType -> dense, length -> 2, values -> List(0.6969391887509891, 0.30306081124901085))",0.0
1,0,1,0.0,6,6,9.0,30.0,30.0,30.0,4.0,0.0,0.0212765957446808,0.0,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,777,1156,0.4019658561821003,1156,1933,0.5980341438178997,12.62608695652174,0.0,30.0,60265,242996,0.2480081976658052,0.3500259461520945,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 6.0, 6.0, 9.0, 30.0, 30.0, 30.0, 4.0, 0.0, 0.02127659574468085, 0.0, 0.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 777.0, 1156.0, 0.40196585618210035, 1156.0, 1933.0, 0.5980341438178997, 12.626086956521739, 0.0, 30.0, 60265.0, 242996.0, 0.2480081976658052, 0.3500259461520945))","Map(vectorType -> dense, length -> 2, values -> List(18.260602012665807, 1.7393979873341965))","Map(vectorType -> dense, length -> 2, values -> List(0.9130301006332902, 0.0869698993667098))",0.0
1,0,1,0.0,5,5,6.0,30.0,30.0,30.0,0.0,12.0,0.0212765957446808,0.0,1,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,23593,15595,0.6020465448606717,15595,39188,0.3979534551393283,11.654457253427084,0.0,30.0,159213,3418021,0.0465804627882625,0.3513729923510658,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 5.0, 5.0, 6.0, 30.0, 30.0, 30.0, 0.0, 12.0, 0.02127659574468085, 0.0, 1.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 23593.0, 15595.0, 0.6020465448606717, 15595.0, 39188.0, 0.39795345513932834, 11.654457253427084, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3513729923510658))","Map(vectorType -> dense, length -> 2, values -> List(17.914093893167028, 2.0859061068329754))","Map(vectorType -> dense, length -> 2, values -> List(0.8957046946583512, 0.10429530534164876))",0.0
2,1,1,0.5,6,5,6.5,30.0,30.0,30.0,2.0,6.0,0.0425531914893617,0.0476190476190476,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,13806,9582,0.590302719343253,9582,23388,0.409697280656747,10.543885937429051,0.0,30.0,159213,3418021,0.0465804627882625,0.3631168178684845,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 6.0, 5.0, 6.5, 30.0, 30.0, 30.0, 2.0, 6.0, 0.0425531914893617, 0.047619047619047616, 0.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 13806.0, 9582.0, 0.590302719343253, 9582.0, 23388.0, 0.40969728065674704, 10.543885937429051, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3631168178684845))","Map(vectorType -> dense, length -> 2, values -> List(16.91216329884636, 3.0878367011536376))","Map(vectorType -> dense, length -> 2, values -> List(0.845608164942318, 0.15439183505768192))",0.0
1,0,1,0.0,5,5,5.0,30.0,30.0,30.0,0.0,12.0,0.0212765957446808,0.0,1,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,13936,9193,0.6025336158069955,9193,23129,0.3974663841930044,11.050960696677878,0.0,30.0,159418,1765313,0.0903057984618025,0.3071605857312019,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 5.0, 5.0, 5.0, 30.0, 30.0, 30.0, 0.0, 12.0, 0.02127659574468085, 0.0, 1.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 13936.0, 9193.0, 0.6025336158069955, 9193.0, 23129.0, 0.39746638419300445, 11.050960696677878, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.3071605857312019))","Map(vectorType -> dense, length -> 2, values -> List(18.250911168879476, 1.7490888311205262))","Map(vectorType -> dense, length -> 2, values -> List(0.9125455584439737, 0.08745444155602629))",0.0
1,0,1,0.0,1,1,5.0,0.0,0.0,0.0,6.0,12.0,0.0212765957446808,0.0,5,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,2667,5321,0.3338758137205809,5321,7988,0.6661241862794192,12.752408477842003,0.0,30.0,80108,222049,0.3607672180464672,0.3053569682329519,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 5.0, 0.0, 0.0, 0.0, 6.0, 12.0, 0.02127659574468085, 0.0, 5.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 2667.0, 5321.0, 0.3338758137205809, 5321.0, 7988.0, 0.6661241862794192, 12.752408477842003, 0.0, 30.0, 80108.0, 222049.0, 0.3607672180464672, 0.30535696823295194))","Map(vectorType -> dense, length -> 2, values -> List(18.745288157231194, 1.2547118427688049))","Map(vectorType -> dense, length -> 2, values -> List(0.9372644078615597, 0.06273559213844024))",0.0
2,1,1,0.5,6,5,6.5,30.0,30.0,30.0,2.0,6.0,0.0425531914893617,0.0476190476190476,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,33857,21627,0.6102119529954582,21627,55484,0.3897880470045418,10.99423382693818,0.0,30.0,159213,3418021,0.0465804627882625,0.3432075842162793,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 6.0, 5.0, 6.5, 30.0, 30.0, 30.0, 2.0, 6.0, 0.0425531914893617, 0.047619047619047616, 0.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 33857.0, 21627.0, 0.6102119529954582, 21627.0, 55484.0, 0.38978804700454184, 10.994233826938181, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3432075842162793))","Map(vectorType -> dense, length -> 2, values -> List(16.37882603311174, 3.621173966888258))","Map(vectorType -> dense, length -> 2, values -> List(0.8189413016555871, 0.18105869834441293))",0.0


In [0]:
predictions.groupBy('prediction').count().show()

+----------+-------+
|prediction|  count|
+----------+-------+
|       0.0|4828184|
|       1.0|   5105|
+----------+-------+



In [0]:
display(test_sdf_id.limit(10))

user_id,product_id,order_id,row_id
134,36431,831748,0
134,44142,831748,1
134,32650,831748,2
134,5782,831748,3
134,30638,831748,4
134,15290,831748,5
134,43768,831748,6
134,11182,831748,7
134,44234,831748,8
134,21938,831748,9


### 예측 결과를 kaggle submission format으로 변경. 
* kaggle에서 예측 성능을 평가 받기 위해서는 아래와 같이 테스트 데이터의 개별 order_id에 재주문 예측 상품코드를 공백으로 분리하여 연속해서 이어주는 형태로 생성
* train과 test 데이터 세트에서 하나의 user_id는 하나의 order_id를 가짐. 만일 특정 order_id(즉 개별 user의 하나의 주문)에 있는 모든 상품들이 다 첫주문이면 None으로 생성.  
order_id,products  
17,1 2  
34,None  
137,1 2 3  
etc.
* 예측 결과에 order_id, product_id, user_id를 붙여 넣기.
* 예측 확률을 기반으로 예측 reordered 재 결정. reordered 1로 예측 확률이 0.21 보다 크면 reordered 1로 재 설정. 
* 예측 결과의 order_id별로 group by하여 product_id결과를 collect_list()와 udf를 이용하여 연속해서 이어줌

In [0]:
#예측 결과 predictions와 test_sdf_id를 조인하여 붙이기.  
from pyspark.sql.functions import monotonically_increasing_id
# row건수별로 0부터 순차적으로 증가하는 row_id 컬럼을 monotonically_increasing_id()을 이용하여 생성. 
test_sdf_id = test_sdf_id.withColumn("row_id", monotonically_increasing_id())
predictions = predictions.withColumn("row_id", monotonically_increasing_id())

display(predictions.limit(10))

up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction,row_id
3,2,1,0.6666666666666666,4,2,6.0,14.333333333333334,21.0,7.0,4.0,10.666666666666666,0.0638297872340425,0.0952380952380952,2,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,2178,1792,0.5486146095717884,1792,3970,0.4513853904282116,12.340677499311484,0.0,30.0,73840,305655,0.2415795586527294,0.2098058317754821,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 4.0, 2.0, 6.0, 14.333333333333334, 21.0, 7.0, 4.0, 10.666666666666666, 0.06382978723404255, 0.09523809523809523, 2.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 2178.0, 1792.0, 0.5486146095717884, 1792.0, 3970.0, 0.4513853904282116, 12.340677499311484, 0.0, 30.0, 73840.0, 305655.0, 0.2415795586527294, 0.20980583177548218))","Map(vectorType -> dense, length -> 2, values -> List(16.430875460363215, 3.569124539636784))","Map(vectorType -> dense, length -> 2, values -> List(0.8215437730181607, 0.17845622698183922))",0.0,0
3,2,1,0.6666666666666666,6,3,2.6666666666666665,22.33333333333333,30.0,7.0,3.333333333333333,8.0,0.0638297872340425,0.0952380952380952,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,23609,19400,0.5489316189634728,19400,43009,0.4510683810365272,12.56283090774228,0.0,30.0,159213,3418021,0.0465804627882625,0.4044879182482647,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 6.0, 3.0, 2.6666666666666665, 22.333333333333332, 30.0, 7.0, 3.3333333333333335, 8.0, 0.06382978723404255, 0.09523809523809523, 0.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 23609.0, 19400.0, 0.5489316189634728, 19400.0, 43009.0, 0.45106838103652724, 12.562830907742281, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.4044879182482647))","Map(vectorType -> dense, length -> 2, values -> List(14.543791425706058, 5.456208574293942))","Map(vectorType -> dense, length -> 2, values -> List(0.7271895712853029, 0.27281042871469713))",0.0,1
1,0,1,0.0,1,1,6.0,0.0,0.0,0.0,6.0,12.0,0.0212765957446808,0.0,5,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,3207,2671,0.5455937393671316,2671,5878,0.4544062606328683,13.046151039766508,0.0,30.0,60265,242996,0.2480081976658052,0.2063980629670631,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 6.0, 0.0, 0.0, 0.0, 6.0, 12.0, 0.02127659574468085, 0.0, 5.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 3207.0, 2671.0, 0.5455937393671316, 2671.0, 5878.0, 0.4544062606328683, 13.046151039766508, 0.0, 30.0, 60265.0, 242996.0, 0.2480081976658052, 0.20639806296706312))","Map(vectorType -> dense, length -> 2, values -> List(18.745288157231194, 1.2547118427688049))","Map(vectorType -> dense, length -> 2, values -> List(0.9372644078615597, 0.06273559213844024))",0.0,2
5,4,1,0.8,6,1,3.0,18.25,30.0,7.0,4.4,8.8,0.1063829787234042,0.1904761904761904,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,7917,4524,0.6363636363636364,4524,12441,0.3636363636363636,13.058844194624978,0.0,30.0,92240,452134,0.2040103155259281,0.1596260481104355,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 6.0, 1.0, 3.0, 18.25, 30.0, 7.0, 4.4, 8.8, 0.10638297872340426, 0.19047619047619047, 0.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 7917.0, 4524.0, 0.6363636363636364, 4524.0, 12441.0, 0.36363636363636365, 13.058844194624978, 0.0, 30.0, 92240.0, 452134.0, 0.20401031552592816, 0.1596260481104355))","Map(vectorType -> dense, length -> 2, values -> List(13.938783775019786, 6.061216224980218))","Map(vectorType -> dense, length -> 2, values -> List(0.6969391887509891, 0.30306081124901085))",0.0,3
1,0,1,0.0,6,6,9.0,30.0,30.0,30.0,4.0,0.0,0.0212765957446808,0.0,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,777,1156,0.4019658561821003,1156,1933,0.5980341438178997,12.62608695652174,0.0,30.0,60265,242996,0.2480081976658052,0.3500259461520945,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 6.0, 6.0, 9.0, 30.0, 30.0, 30.0, 4.0, 0.0, 0.02127659574468085, 0.0, 0.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 777.0, 1156.0, 0.40196585618210035, 1156.0, 1933.0, 0.5980341438178997, 12.626086956521739, 0.0, 30.0, 60265.0, 242996.0, 0.2480081976658052, 0.3500259461520945))","Map(vectorType -> dense, length -> 2, values -> List(18.260602012665807, 1.7393979873341965))","Map(vectorType -> dense, length -> 2, values -> List(0.9130301006332902, 0.0869698993667098))",0.0,4
1,0,1,0.0,5,5,6.0,30.0,30.0,30.0,0.0,12.0,0.0212765957446808,0.0,1,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,23593,15595,0.6020465448606717,15595,39188,0.3979534551393283,11.654457253427084,0.0,30.0,159213,3418021,0.0465804627882625,0.3513729923510658,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 5.0, 5.0, 6.0, 30.0, 30.0, 30.0, 0.0, 12.0, 0.02127659574468085, 0.0, 1.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 23593.0, 15595.0, 0.6020465448606717, 15595.0, 39188.0, 0.39795345513932834, 11.654457253427084, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3513729923510658))","Map(vectorType -> dense, length -> 2, values -> List(17.914093893167028, 2.0859061068329754))","Map(vectorType -> dense, length -> 2, values -> List(0.8957046946583512, 0.10429530534164876))",0.0,5
2,1,1,0.5,6,5,6.5,30.0,30.0,30.0,2.0,6.0,0.0425531914893617,0.0476190476190476,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,13806,9582,0.590302719343253,9582,23388,0.409697280656747,10.543885937429051,0.0,30.0,159213,3418021,0.0465804627882625,0.3631168178684845,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 6.0, 5.0, 6.5, 30.0, 30.0, 30.0, 2.0, 6.0, 0.0425531914893617, 0.047619047619047616, 0.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 13806.0, 9582.0, 0.590302719343253, 9582.0, 23388.0, 0.40969728065674704, 10.543885937429051, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3631168178684845))","Map(vectorType -> dense, length -> 2, values -> List(16.91216329884636, 3.0878367011536376))","Map(vectorType -> dense, length -> 2, values -> List(0.845608164942318, 0.15439183505768192))",0.0,6
1,0,1,0.0,5,5,5.0,30.0,30.0,30.0,0.0,12.0,0.0212765957446808,0.0,1,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,13936,9193,0.6025336158069955,9193,23129,0.3974663841930044,11.050960696677878,0.0,30.0,159418,1765313,0.0903057984618025,0.3071605857312019,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 5.0, 5.0, 5.0, 30.0, 30.0, 30.0, 0.0, 12.0, 0.02127659574468085, 0.0, 1.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 13936.0, 9193.0, 0.6025336158069955, 9193.0, 23129.0, 0.39746638419300445, 11.050960696677878, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.3071605857312019))","Map(vectorType -> dense, length -> 2, values -> List(18.250911168879476, 1.7490888311205262))","Map(vectorType -> dense, length -> 2, values -> List(0.9125455584439737, 0.08745444155602629))",0.0,7
1,0,1,0.0,1,1,5.0,0.0,0.0,0.0,6.0,12.0,0.0212765957446808,0.0,5,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,2667,5321,0.3338758137205809,5321,7988,0.6661241862794192,12.752408477842003,0.0,30.0,80108,222049,0.3607672180464672,0.3053569682329519,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 5.0, 0.0, 0.0, 0.0, 6.0, 12.0, 0.02127659574468085, 0.0, 5.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 2667.0, 5321.0, 0.3338758137205809, 5321.0, 7988.0, 0.6661241862794192, 12.752408477842003, 0.0, 30.0, 80108.0, 222049.0, 0.3607672180464672, 0.30535696823295194))","Map(vectorType -> dense, length -> 2, values -> List(18.745288157231194, 1.2547118427688049))","Map(vectorType -> dense, length -> 2, values -> List(0.9372644078615597, 0.06273559213844024))",0.0,8
2,1,1,0.5,6,5,6.5,30.0,30.0,30.0,2.0,6.0,0.0425531914893617,0.0476190476190476,0,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,33857,21627,0.6102119529954582,21627,55484,0.3897880470045418,10.99423382693818,0.0,30.0,159213,3418021,0.0465804627882625,0.3432075842162793,"Map(vectorType -> dense, length -> 44, values -> List(2.0, 1.0, 1.0, 0.5, 6.0, 5.0, 6.5, 30.0, 30.0, 30.0, 2.0, 6.0, 0.0425531914893617, 0.047619047619047616, 0.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 33857.0, 21627.0, 0.6102119529954582, 21627.0, 55484.0, 0.38978804700454184, 10.994233826938181, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.3432075842162793))","Map(vectorType -> dense, length -> 2, values -> List(16.37882603311174, 3.621173966888258))","Map(vectorType -> dense, length -> 2, values -> List(0.8189413016555871, 0.18105869834441293))",0.0,9


In [0]:
# order_id와 product_id를 얻기 위해 test_sdf_id와 predictions을 row_id로 조인 시킴. 
predictions = test_sdf_id.join(predictions, ("row_id")).drop("row_id")
print(test_sdf.count(), predictions.count())
display(predictions.limit(10))

4833289 4833289


user_id,product_id,order_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction
134,36431,831748,3,2,1,0.6666666666666666,4,2,6.0,14.333333333333334,21.0,7.0,4.0,10.666666666666666,0.0638297872340425,0.0952380952380952,2,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,2178,1792,0.5486146095717884,1792,3970,0.4513853904282116,12.340677499311484,0.0,30.0,73840,305655,0.2415795586527294,0.2098058317754821,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 4.0, 2.0, 6.0, 14.333333333333334, 21.0, 7.0, 4.0, 10.666666666666666, 0.06382978723404255, 0.09523809523809523, 2.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 2178.0, 1792.0, 0.5486146095717884, 1792.0, 3970.0, 0.4513853904282116, 12.340677499311484, 0.0, 30.0, 73840.0, 305655.0, 0.2415795586527294, 0.20980583177548218))","Map(vectorType -> dense, length -> 2, values -> List(16.430875460363215, 3.569124539636784))","Map(vectorType -> dense, length -> 2, values -> List(0.8215437730181607, 0.17845622698183922))",0.0
134,16953,831748,1,0,1,0.0,3,3,6.0,7.0,7.0,7.0,6.0,12.0,0.0212765957446808,0.0,3,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,10503,9530,0.5242849298657215,9530,20033,0.4757150701342784,12.2864745726266,0.0,30.0,81973,289400,0.2832515549412577,0.1924635151930206,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 6.0, 7.0, 7.0, 7.0, 6.0, 12.0, 0.02127659574468085, 0.0, 3.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 10503.0, 9530.0, 0.5242849298657215, 9530.0, 20033.0, 0.47571507013427844, 12.286474572626599, 0.0, 30.0, 81973.0, 289400.0, 0.28325155494125775, 0.19246351519302068))","Map(vectorType -> dense, length -> 2, values -> List(18.745288157231194, 1.2547118427688049))","Map(vectorType -> dense, length -> 2, values -> List(0.9372644078615597, 0.06273559213844024))",0.0
153,14992,1658650,10,9,1,0.9,16,2,4.3,12.2,28.0,6.0,1.7,16.2,0.0374531835205992,0.0508474576271186,3,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,16942,12127,0.5828201864529223,12127,29069,0.4171798135470776,11.64868977804172,0.0,30.0,159213,3418021,0.0465804627882625,0.370599350758815,"Map(vectorType -> dense, length -> 44, values -> List(10.0, 9.0, 1.0, 0.9, 16.0, 2.0, 4.3, 12.2, 28.0, 6.0, 1.7, 16.2, 0.03745318352059925, 0.05084745762711865, 3.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 16942.0, 12127.0, 0.5828201864529223, 12127.0, 29069.0, 0.4171798135470776, 11.64868977804172, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.37059935075881506))","Map(vectorType -> dense, length -> 2, values -> List(15.340715909550475, 4.659284090449524))","Map(vectorType -> dense, length -> 2, values -> List(0.7670357954775238, 0.23296420452247618))",0.0
153,33653,1658650,1,0,1,0.0,10,10,6.0,7.0,7.0,7.0,1.0,12.0,0.0037453183520599,0.0,9,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,87,52,0.6258992805755396,52,139,0.3741007194244604,9.7109375,0.0,30.0,73124,266637,0.2742455098129667,0.0998552096114937,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 10.0, 10.0, 6.0, 7.0, 7.0, 7.0, 1.0, 12.0, 0.003745318352059925, 0.0, 9.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 87.0, 52.0, 0.6258992805755396, 52.0, 139.0, 0.37410071942446044, 9.7109375, 0.0, 30.0, 73124.0, 266637.0, 0.2742455098129667, 0.09985520961149374))","Map(vectorType -> dense, length -> 2, values -> List(18.853207545862247, 1.1467924541377548))","Map(vectorType -> dense, length -> 2, values -> List(0.9426603772931121, 0.05733962270688773))",0.0
153,21903,1658650,7,6,1,0.8571428571428571,18,9,5.285714285714286,12.857142857142858,28.0,6.0,1.0,16.0,0.0262172284644194,0.0338983050847457,1,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,186884,55037,0.7725001136734719,55037,241921,0.2274998863265281,11.199653703303918,0.0,30.0,159418,1765313,0.0903057984618025,0.1371940878647256,"Map(vectorType -> dense, length -> 44, values -> List(7.0, 6.0, 1.0, 0.8571428571428571, 18.0, 9.0, 5.285714285714286, 12.857142857142858, 28.0, 6.0, 1.0, 16.0, 0.026217228464419477, 0.03389830508474576, 1.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 186884.0, 55037.0, 0.7725001136734719, 55037.0, 241921.0, 0.2274998863265281, 11.199653703303918, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.1371940878647256))","Map(vectorType -> dense, length -> 2, values -> List(14.838104547828625, 5.161895452171375))","Map(vectorType -> dense, length -> 2, values -> List(0.7419052273914313, 0.2580947726085688))",0.0
153,38159,1658650,4,3,1,0.75,19,1,5.75,9.333333333333334,12.0,7.0,1.25,14.75,0.0149812734082397,0.0169491525423728,0,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,12789,8512,0.6003943476832073,8512,21301,0.3996056523167926,9.67922380080194,0.0,30.0,177141,3642188,0.048635874919142,0.3509697773976506,"Map(vectorType -> dense, length -> 44, values -> List(4.0, 3.0, 1.0, 0.75, 19.0, 1.0, 5.75, 9.333333333333334, 12.0, 7.0, 1.25, 14.75, 0.0149812734082397, 0.01694915254237288, 0.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 12789.0, 8512.0, 0.6003943476832073, 8512.0, 21301.0, 0.3996056523167926, 9.67922380080194, 0.0, 30.0, 177141.0, 3642188.0, 0.04863587491914201, 0.3509697773976506))","Map(vectorType -> dense, length -> 2, values -> List(17.320692924288814, 2.679307075711183))","Map(vectorType -> dense, length -> 2, values -> List(0.8660346462144408, 0.13396535378555918))",0.0
180,30480,2769561,1,0,1,0.0,4,4,32.0,28.0,28.0,28.0,6.0,10.0,0.0045045045045045,0.0,2,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,3062,3151,0.492837598583615,3151,6213,0.507162401416385,11.187958383080334,0.0,30.0,76476,297037,0.2574628749953709,0.249699526421014,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 4.0, 4.0, 32.0, 28.0, 28.0, 28.0, 6.0, 10.0, 0.0045045045045045045, 0.0, 2.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 3062.0, 3151.0, 0.492837598583615, 3151.0, 6213.0, 0.507162401416385, 11.187958383080334, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.24969952642101406))","Map(vectorType -> dense, length -> 2, values -> List(18.61067777789579, 1.3893222221042096))","Map(vectorType -> dense, length -> 2, values -> List(0.9305338888947896, 0.06946611110521048))",0.0
180,22963,2769561,5,4,1,0.8,6,1,4.8,26.0,30.0,18.0,5.2,10.0,0.0225225225225225,0.0373831775700934,0,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,18893,9491,0.6656214768883878,9491,28384,0.3343785231116122,10.781790169848188,0.0,30.0,78030,395130,0.1974793106066358,0.1368992125049764,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 6.0, 1.0, 4.8, 26.0, 30.0, 18.0, 5.2, 10.0, 0.02252252252252252, 0.037383177570093455, 0.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 18893.0, 9491.0, 0.6656214768883878, 9491.0, 28384.0, 0.3343785231116122, 10.781790169848188, 0.0, 30.0, 78030.0, 395130.0, 0.1974793106066358, 0.13689921250497641))","Map(vectorType -> dense, length -> 2, values -> List(14.360743568272248, 5.6392564317277545))","Map(vectorType -> dense, length -> 2, values -> List(0.7180371784136123, 0.2819628215863877))",0.0
180,3376,2769561,4,3,1,0.75,4,1,29.25,28.666666666666668,30.0,28.0,5.0,10.25,0.018018018018018,0.02803738317757,2,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,6281,6457,0.4930915371329879,6457,12738,0.5069084628670121,11.447830101569714,0.0,30.0,76476,297037,0.2574628749953709,0.2494455878716411,"Map(vectorType -> dense, length -> 44, values -> List(4.0, 3.0, 1.0, 0.75, 4.0, 1.0, 29.25, 28.666666666666668, 30.0, 28.0, 5.0, 10.25, 0.018018018018018018, 0.028037383177570093, 2.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 6281.0, 6457.0, 0.4930915371329879, 6457.0, 12738.0, 0.5069084628670121, 11.447830101569714, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.24944558787164112))","Map(vectorType -> dense, length -> 2, values -> List(17.033499090494846, 2.9665009095051533))","Map(vectorType -> dense, length -> 2, values -> List(0.8516749545247423, 0.14832504547525766))",0.0
186,24787,470997,1,0,1,0.0,2,2,5.0,18.0,18.0,18.0,3.0,17.0,0.0185185185185185,0.0,5,54,38,7,7.714285714285714,1.4210526315789471,0.7037037037037037,16,38,0.2962962962962963,15.959183673469388,30.0,3.0,3.1481481481481484,12.11111111111111,7,5.0,209,269,0.4372384937238494,269,478,0.5627615062761506,12.713302752293576,0.0,30.0,6631,15901,0.417017797622791,0.1457437086533596,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 2.0, 2.0, 5.0, 18.0, 18.0, 18.0, 3.0, 17.0, 0.018518518518518517, 0.0, 5.0, 54.0, 38.0, 7.0, 7.714285714285714, 1.4210526315789473, 0.7037037037037037, 16.0, 38.0, 0.2962962962962963, 15.959183673469388, 30.0, 3.0, 3.1481481481481484, 12.11111111111111, 7.0, 5.0, 209.0, 269.0, 0.4372384937238494, 269.0, 478.0, 0.5627615062761506, 12.713302752293577, 0.0, 30.0, 6631.0, 15901.0, 0.417017797622791, 0.14574370865335962))","Map(vectorType -> dense, length -> 2, values -> List(18.61067777789579, 1.3893222221042096))","Map(vectorType -> dense, length -> 2, values -> List(0.9305338888947896, 0.06946611110521048))",0.0


In [0]:
predictions.printSchema()

root
 |-- user_id: integer (nullable = true)
 |-- product_id: long (nullable = true)
 |-- order_id: integer (nullable = true)
 |-- up_cnt: long (nullable = true)
 |-- up_reord_cnt: long (nullable = true)
 |-- up_no_reord_cnt: long (nullable = true)
 |-- up_reoredered_avg: double (nullable = false)
 |-- up_max_ord_num: integer (nullable = true)
 |-- up_min_ord_num: integer (nullable = true)
 |-- up_avg_cart: double (nullable = false)
 |-- up_avg_prior_days: double (nullable = false)
 |-- up_max_prior_days: double (nullable = false)
 |-- up_min_prior_days: double (nullable = false)
 |-- up_avg_ord_dow: double (nullable = false)
 |-- up_avg_ord_hour: double (nullable = false)
 |-- up_usr_ratio: double (nullable = false)
 |-- up_usr_reord_ratio: double (nullable = false)
 |-- up_usr_ord_num_diff: integer (nullable = true)
 |-- usr_total_cnt: long (nullable = true)
 |-- prd_uq_cnt: long (nullable = true)
 |-- order_uq_cnt: long (nullable = true)
 |-- usr_avg_prd_cnt: double (nullable = fals

In [0]:
# 여러 값으로 구성된 vector 컬럼에서 특정 값만 추출. probability 컬럼은 0/1 일때의 확률을 모두 가짐. 이중 1일 때(즉 재주문)의 확률을 추출
# 먼저 vector를 array로 변환
from pyspark.ml.functions import vector_to_array
predictions = predictions.withColumn("probability_arr", vector_to_array('probability'))
display(predictions.limit(10))

user_id,product_id,order_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction,probability_arr
134,36431,831748,3,2,1,0.6666666666666666,4,2,6.0,14.333333333333334,21.0,7.0,4.0,10.666666666666666,0.0638297872340425,0.0952380952380952,2,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,2178,1792,0.5486146095717884,1792,3970,0.4513853904282116,12.340677499311484,0.0,30.0,73840,305655,0.2415795586527294,0.2098058317754821,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 4.0, 2.0, 6.0, 14.333333333333334, 21.0, 7.0, 4.0, 10.666666666666666, 0.06382978723404255, 0.09523809523809523, 2.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 2178.0, 1792.0, 0.5486146095717884, 1792.0, 3970.0, 0.4513853904282116, 12.340677499311484, 0.0, 30.0, 73840.0, 305655.0, 0.2415795586527294, 0.20980583177548218))","Map(vectorType -> dense, length -> 2, values -> List(16.430875460363215, 3.569124539636784))","Map(vectorType -> dense, length -> 2, values -> List(0.8215437730181607, 0.17845622698183922))",0.0,"List(0.8215437730181607, 0.17845622698183922)"
134,16953,831748,1,0,1,0.0,3,3,6.0,7.0,7.0,7.0,6.0,12.0,0.0212765957446808,0.0,3,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,10503,9530,0.5242849298657215,9530,20033,0.4757150701342784,12.2864745726266,0.0,30.0,81973,289400,0.2832515549412577,0.1924635151930206,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 6.0, 7.0, 7.0, 7.0, 6.0, 12.0, 0.02127659574468085, 0.0, 3.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 10503.0, 9530.0, 0.5242849298657215, 9530.0, 20033.0, 0.47571507013427844, 12.286474572626599, 0.0, 30.0, 81973.0, 289400.0, 0.28325155494125775, 0.19246351519302068))","Map(vectorType -> dense, length -> 2, values -> List(18.745288157231194, 1.2547118427688049))","Map(vectorType -> dense, length -> 2, values -> List(0.9372644078615597, 0.06273559213844024))",0.0,"List(0.9372644078615597, 0.06273559213844024)"
153,14992,1658650,10,9,1,0.9,16,2,4.3,12.2,28.0,6.0,1.7,16.2,0.0374531835205992,0.0508474576271186,3,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,16942,12127,0.5828201864529223,12127,29069,0.4171798135470776,11.64868977804172,0.0,30.0,159213,3418021,0.0465804627882625,0.370599350758815,"Map(vectorType -> dense, length -> 44, values -> List(10.0, 9.0, 1.0, 0.9, 16.0, 2.0, 4.3, 12.2, 28.0, 6.0, 1.7, 16.2, 0.03745318352059925, 0.05084745762711865, 3.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 16942.0, 12127.0, 0.5828201864529223, 12127.0, 29069.0, 0.4171798135470776, 11.64868977804172, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.37059935075881506))","Map(vectorType -> dense, length -> 2, values -> List(15.340715909550475, 4.659284090449524))","Map(vectorType -> dense, length -> 2, values -> List(0.7670357954775238, 0.23296420452247618))",0.0,"List(0.7670357954775238, 0.23296420452247618)"
153,33653,1658650,1,0,1,0.0,10,10,6.0,7.0,7.0,7.0,1.0,12.0,0.0037453183520599,0.0,9,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,87,52,0.6258992805755396,52,139,0.3741007194244604,9.7109375,0.0,30.0,73124,266637,0.2742455098129667,0.0998552096114937,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 10.0, 10.0, 6.0, 7.0, 7.0, 7.0, 1.0, 12.0, 0.003745318352059925, 0.0, 9.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 87.0, 52.0, 0.6258992805755396, 52.0, 139.0, 0.37410071942446044, 9.7109375, 0.0, 30.0, 73124.0, 266637.0, 0.2742455098129667, 0.09985520961149374))","Map(vectorType -> dense, length -> 2, values -> List(18.853207545862247, 1.1467924541377548))","Map(vectorType -> dense, length -> 2, values -> List(0.9426603772931121, 0.05733962270688773))",0.0,"List(0.9426603772931121, 0.05733962270688773)"
153,21903,1658650,7,6,1,0.8571428571428571,18,9,5.285714285714286,12.857142857142858,28.0,6.0,1.0,16.0,0.0262172284644194,0.0338983050847457,1,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,186884,55037,0.7725001136734719,55037,241921,0.2274998863265281,11.199653703303918,0.0,30.0,159418,1765313,0.0903057984618025,0.1371940878647256,"Map(vectorType -> dense, length -> 44, values -> List(7.0, 6.0, 1.0, 0.8571428571428571, 18.0, 9.0, 5.285714285714286, 12.857142857142858, 28.0, 6.0, 1.0, 16.0, 0.026217228464419477, 0.03389830508474576, 1.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 186884.0, 55037.0, 0.7725001136734719, 55037.0, 241921.0, 0.2274998863265281, 11.199653703303918, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.1371940878647256))","Map(vectorType -> dense, length -> 2, values -> List(14.838104547828625, 5.161895452171375))","Map(vectorType -> dense, length -> 2, values -> List(0.7419052273914313, 0.2580947726085688))",0.0,"List(0.7419052273914313, 0.2580947726085688)"
153,38159,1658650,4,3,1,0.75,19,1,5.75,9.333333333333334,12.0,7.0,1.25,14.75,0.0149812734082397,0.0169491525423728,0,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,12789,8512,0.6003943476832073,8512,21301,0.3996056523167926,9.67922380080194,0.0,30.0,177141,3642188,0.048635874919142,0.3509697773976506,"Map(vectorType -> dense, length -> 44, values -> List(4.0, 3.0, 1.0, 0.75, 19.0, 1.0, 5.75, 9.333333333333334, 12.0, 7.0, 1.25, 14.75, 0.0149812734082397, 0.01694915254237288, 0.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 12789.0, 8512.0, 0.6003943476832073, 8512.0, 21301.0, 0.3996056523167926, 9.67922380080194, 0.0, 30.0, 177141.0, 3642188.0, 0.04863587491914201, 0.3509697773976506))","Map(vectorType -> dense, length -> 2, values -> List(17.320692924288814, 2.679307075711183))","Map(vectorType -> dense, length -> 2, values -> List(0.8660346462144408, 0.13396535378555918))",0.0,"List(0.8660346462144408, 0.13396535378555918)"
180,30480,2769561,1,0,1,0.0,4,4,32.0,28.0,28.0,28.0,6.0,10.0,0.0045045045045045,0.0,2,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,3062,3151,0.492837598583615,3151,6213,0.507162401416385,11.187958383080334,0.0,30.0,76476,297037,0.2574628749953709,0.249699526421014,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 4.0, 4.0, 32.0, 28.0, 28.0, 28.0, 6.0, 10.0, 0.0045045045045045045, 0.0, 2.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 3062.0, 3151.0, 0.492837598583615, 3151.0, 6213.0, 0.507162401416385, 11.187958383080334, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.24969952642101406))","Map(vectorType -> dense, length -> 2, values -> List(18.61067777789579, 1.3893222221042096))","Map(vectorType -> dense, length -> 2, values -> List(0.9305338888947896, 0.06946611110521048))",0.0,"List(0.9305338888947896, 0.06946611110521048)"
180,22963,2769561,5,4,1,0.8,6,1,4.8,26.0,30.0,18.0,5.2,10.0,0.0225225225225225,0.0373831775700934,0,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,18893,9491,0.6656214768883878,9491,28384,0.3343785231116122,10.781790169848188,0.0,30.0,78030,395130,0.1974793106066358,0.1368992125049764,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 6.0, 1.0, 4.8, 26.0, 30.0, 18.0, 5.2, 10.0, 0.02252252252252252, 0.037383177570093455, 0.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 18893.0, 9491.0, 0.6656214768883878, 9491.0, 28384.0, 0.3343785231116122, 10.781790169848188, 0.0, 30.0, 78030.0, 395130.0, 0.1974793106066358, 0.13689921250497641))","Map(vectorType -> dense, length -> 2, values -> List(14.360743568272248, 5.6392564317277545))","Map(vectorType -> dense, length -> 2, values -> List(0.7180371784136123, 0.2819628215863877))",0.0,"List(0.7180371784136123, 0.2819628215863877)"
180,3376,2769561,4,3,1,0.75,4,1,29.25,28.666666666666668,30.0,28.0,5.0,10.25,0.018018018018018,0.02803738317757,2,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,6281,6457,0.4930915371329879,6457,12738,0.5069084628670121,11.447830101569714,0.0,30.0,76476,297037,0.2574628749953709,0.2494455878716411,"Map(vectorType -> dense, length -> 44, values -> List(4.0, 3.0, 1.0, 0.75, 4.0, 1.0, 29.25, 28.666666666666668, 30.0, 28.0, 5.0, 10.25, 0.018018018018018018, 0.028037383177570093, 2.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 6281.0, 6457.0, 0.4930915371329879, 6457.0, 12738.0, 0.5069084628670121, 11.447830101569714, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.24944558787164112))","Map(vectorType -> dense, length -> 2, values -> List(17.033499090494846, 2.9665009095051533))","Map(vectorType -> dense, length -> 2, values -> List(0.8516749545247423, 0.14832504547525766))",0.0,"List(0.8516749545247423, 0.14832504547525766)"
186,24787,470997,1,0,1,0.0,2,2,5.0,18.0,18.0,18.0,3.0,17.0,0.0185185185185185,0.0,5,54,38,7,7.714285714285714,1.4210526315789471,0.7037037037037037,16,38,0.2962962962962963,15.959183673469388,30.0,3.0,3.1481481481481484,12.11111111111111,7,5.0,209,269,0.4372384937238494,269,478,0.5627615062761506,12.713302752293576,0.0,30.0,6631,15901,0.417017797622791,0.1457437086533596,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 2.0, 2.0, 5.0, 18.0, 18.0, 18.0, 3.0, 17.0, 0.018518518518518517, 0.0, 5.0, 54.0, 38.0, 7.0, 7.714285714285714, 1.4210526315789473, 0.7037037037037037, 16.0, 38.0, 0.2962962962962963, 15.959183673469388, 30.0, 3.0, 3.1481481481481484, 12.11111111111111, 7.0, 5.0, 209.0, 269.0, 0.4372384937238494, 269.0, 478.0, 0.5627615062761506, 12.713302752293577, 0.0, 30.0, 6631.0, 15901.0, 0.417017797622791, 0.14574370865335962))","Map(vectorType -> dense, length -> 2, values -> List(18.61067777789579, 1.3893222221042096))","Map(vectorType -> dense, length -> 2, values -> List(0.9305338888947896, 0.06946611110521048))",0.0,"List(0.9305338888947896, 0.06946611110521048)"


In [0]:
predictions.select(F.col('probability_arr')[1]).show(10)

+-------------------+
| probability_arr[1]|
+-------------------+
|0.17845622698183922|
|0.27281042871469713|
|0.06273559213844024|
|0.30306081124901085|
| 0.0869698993667098|
|0.10429530534164876|
|0.15439183505768192|
|0.08745444155602629|
|0.06273559213844024|
|0.18105869834441293|
+-------------------+
only showing top 10 rows



In [0]:
# 변환된 array에서 1일때의 확률값을 추출. 
predictions = predictions.withColumn('1_proba', F.col('probability_arr')[1])
display(predictions.limit(10))

user_id,product_id,order_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction,probability_arr,1_proba
134,36431,831748,3,2,1,0.6666666666666666,4,2,6.0,14.333333333333334,21.0,7.0,4.0,10.666666666666666,0.0638297872340425,0.0952380952380952,2,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,2178,1792,0.5486146095717884,1792,3970,0.4513853904282116,12.340677499311484,0.0,30.0,73840,305655,0.2415795586527294,0.2098058317754821,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 4.0, 2.0, 6.0, 14.333333333333334, 21.0, 7.0, 4.0, 10.666666666666666, 0.06382978723404255, 0.09523809523809523, 2.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 2178.0, 1792.0, 0.5486146095717884, 1792.0, 3970.0, 0.4513853904282116, 12.340677499311484, 0.0, 30.0, 73840.0, 305655.0, 0.2415795586527294, 0.20980583177548218))","Map(vectorType -> dense, length -> 2, values -> List(16.430875460363215, 3.569124539636784))","Map(vectorType -> dense, length -> 2, values -> List(0.8215437730181607, 0.17845622698183922))",0.0,"List(0.8215437730181607, 0.17845622698183922)",0.1784562269818392
134,16953,831748,1,0,1,0.0,3,3,6.0,7.0,7.0,7.0,6.0,12.0,0.0212765957446808,0.0,3,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,10503,9530,0.5242849298657215,9530,20033,0.4757150701342784,12.2864745726266,0.0,30.0,81973,289400,0.2832515549412577,0.1924635151930206,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 6.0, 7.0, 7.0, 7.0, 6.0, 12.0, 0.02127659574468085, 0.0, 3.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 10503.0, 9530.0, 0.5242849298657215, 9530.0, 20033.0, 0.47571507013427844, 12.286474572626599, 0.0, 30.0, 81973.0, 289400.0, 0.28325155494125775, 0.19246351519302068))","Map(vectorType -> dense, length -> 2, values -> List(18.745288157231194, 1.2547118427688049))","Map(vectorType -> dense, length -> 2, values -> List(0.9372644078615597, 0.06273559213844024))",0.0,"List(0.9372644078615597, 0.06273559213844024)",0.0627355921384402
153,14992,1658650,10,9,1,0.9,16,2,4.3,12.2,28.0,6.0,1.7,16.2,0.0374531835205992,0.0508474576271186,3,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,16942,12127,0.5828201864529223,12127,29069,0.4171798135470776,11.64868977804172,0.0,30.0,159213,3418021,0.0465804627882625,0.370599350758815,"Map(vectorType -> dense, length -> 44, values -> List(10.0, 9.0, 1.0, 0.9, 16.0, 2.0, 4.3, 12.2, 28.0, 6.0, 1.7, 16.2, 0.03745318352059925, 0.05084745762711865, 3.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 16942.0, 12127.0, 0.5828201864529223, 12127.0, 29069.0, 0.4171798135470776, 11.64868977804172, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.37059935075881506))","Map(vectorType -> dense, length -> 2, values -> List(15.340715909550475, 4.659284090449524))","Map(vectorType -> dense, length -> 2, values -> List(0.7670357954775238, 0.23296420452247618))",0.0,"List(0.7670357954775238, 0.23296420452247618)",0.2329642045224761
153,33653,1658650,1,0,1,0.0,10,10,6.0,7.0,7.0,7.0,1.0,12.0,0.0037453183520599,0.0,9,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,87,52,0.6258992805755396,52,139,0.3741007194244604,9.7109375,0.0,30.0,73124,266637,0.2742455098129667,0.0998552096114937,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 10.0, 10.0, 6.0, 7.0, 7.0, 7.0, 1.0, 12.0, 0.003745318352059925, 0.0, 9.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 87.0, 52.0, 0.6258992805755396, 52.0, 139.0, 0.37410071942446044, 9.7109375, 0.0, 30.0, 73124.0, 266637.0, 0.2742455098129667, 0.09985520961149374))","Map(vectorType -> dense, length -> 2, values -> List(18.853207545862247, 1.1467924541377548))","Map(vectorType -> dense, length -> 2, values -> List(0.9426603772931121, 0.05733962270688773))",0.0,"List(0.9426603772931121, 0.05733962270688773)",0.0573396227068877
153,21903,1658650,7,6,1,0.8571428571428571,18,9,5.285714285714286,12.857142857142858,28.0,6.0,1.0,16.0,0.0262172284644194,0.0338983050847457,1,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,186884,55037,0.7725001136734719,55037,241921,0.2274998863265281,11.199653703303918,0.0,30.0,159418,1765313,0.0903057984618025,0.1371940878647256,"Map(vectorType -> dense, length -> 44, values -> List(7.0, 6.0, 1.0, 0.8571428571428571, 18.0, 9.0, 5.285714285714286, 12.857142857142858, 28.0, 6.0, 1.0, 16.0, 0.026217228464419477, 0.03389830508474576, 1.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 186884.0, 55037.0, 0.7725001136734719, 55037.0, 241921.0, 0.2274998863265281, 11.199653703303918, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.1371940878647256))","Map(vectorType -> dense, length -> 2, values -> List(14.838104547828625, 5.161895452171375))","Map(vectorType -> dense, length -> 2, values -> List(0.7419052273914313, 0.2580947726085688))",0.0,"List(0.7419052273914313, 0.2580947726085688)",0.2580947726085688
153,38159,1658650,4,3,1,0.75,19,1,5.75,9.333333333333334,12.0,7.0,1.25,14.75,0.0149812734082397,0.0169491525423728,0,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,12789,8512,0.6003943476832073,8512,21301,0.3996056523167926,9.67922380080194,0.0,30.0,177141,3642188,0.048635874919142,0.3509697773976506,"Map(vectorType -> dense, length -> 44, values -> List(4.0, 3.0, 1.0, 0.75, 19.0, 1.0, 5.75, 9.333333333333334, 12.0, 7.0, 1.25, 14.75, 0.0149812734082397, 0.01694915254237288, 0.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 12789.0, 8512.0, 0.6003943476832073, 8512.0, 21301.0, 0.3996056523167926, 9.67922380080194, 0.0, 30.0, 177141.0, 3642188.0, 0.04863587491914201, 0.3509697773976506))","Map(vectorType -> dense, length -> 2, values -> List(17.320692924288814, 2.679307075711183))","Map(vectorType -> dense, length -> 2, values -> List(0.8660346462144408, 0.13396535378555918))",0.0,"List(0.8660346462144408, 0.13396535378555918)",0.1339653537855591
180,30480,2769561,1,0,1,0.0,4,4,32.0,28.0,28.0,28.0,6.0,10.0,0.0045045045045045,0.0,2,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,3062,3151,0.492837598583615,3151,6213,0.507162401416385,11.187958383080334,0.0,30.0,76476,297037,0.2574628749953709,0.249699526421014,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 4.0, 4.0, 32.0, 28.0, 28.0, 28.0, 6.0, 10.0, 0.0045045045045045045, 0.0, 2.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 3062.0, 3151.0, 0.492837598583615, 3151.0, 6213.0, 0.507162401416385, 11.187958383080334, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.24969952642101406))","Map(vectorType -> dense, length -> 2, values -> List(18.61067777789579, 1.3893222221042096))","Map(vectorType -> dense, length -> 2, values -> List(0.9305338888947896, 0.06946611110521048))",0.0,"List(0.9305338888947896, 0.06946611110521048)",0.0694661111052104
180,22963,2769561,5,4,1,0.8,6,1,4.8,26.0,30.0,18.0,5.2,10.0,0.0225225225225225,0.0373831775700934,0,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,18893,9491,0.6656214768883878,9491,28384,0.3343785231116122,10.781790169848188,0.0,30.0,78030,395130,0.1974793106066358,0.1368992125049764,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 6.0, 1.0, 4.8, 26.0, 30.0, 18.0, 5.2, 10.0, 0.02252252252252252, 0.037383177570093455, 0.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 18893.0, 9491.0, 0.6656214768883878, 9491.0, 28384.0, 0.3343785231116122, 10.781790169848188, 0.0, 30.0, 78030.0, 395130.0, 0.1974793106066358, 0.13689921250497641))","Map(vectorType -> dense, length -> 2, values -> List(14.360743568272248, 5.6392564317277545))","Map(vectorType -> dense, length -> 2, values -> List(0.7180371784136123, 0.2819628215863877))",0.0,"List(0.7180371784136123, 0.2819628215863877)",0.2819628215863877
180,3376,2769561,4,3,1,0.75,4,1,29.25,28.666666666666668,30.0,28.0,5.0,10.25,0.018018018018018,0.02803738317757,2,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,6281,6457,0.4930915371329879,6457,12738,0.5069084628670121,11.447830101569714,0.0,30.0,76476,297037,0.2574628749953709,0.2494455878716411,"Map(vectorType -> dense, length -> 44, values -> List(4.0, 3.0, 1.0, 0.75, 4.0, 1.0, 29.25, 28.666666666666668, 30.0, 28.0, 5.0, 10.25, 0.018018018018018018, 0.028037383177570093, 2.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 6281.0, 6457.0, 0.4930915371329879, 6457.0, 12738.0, 0.5069084628670121, 11.447830101569714, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.24944558787164112))","Map(vectorType -> dense, length -> 2, values -> List(17.033499090494846, 2.9665009095051533))","Map(vectorType -> dense, length -> 2, values -> List(0.8516749545247423, 0.14832504547525766))",0.0,"List(0.8516749545247423, 0.14832504547525766)",0.1483250454752576
186,24787,470997,1,0,1,0.0,2,2,5.0,18.0,18.0,18.0,3.0,17.0,0.0185185185185185,0.0,5,54,38,7,7.714285714285714,1.4210526315789471,0.7037037037037037,16,38,0.2962962962962963,15.959183673469388,30.0,3.0,3.1481481481481484,12.11111111111111,7,5.0,209,269,0.4372384937238494,269,478,0.5627615062761506,12.713302752293576,0.0,30.0,6631,15901,0.417017797622791,0.1457437086533596,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 2.0, 2.0, 5.0, 18.0, 18.0, 18.0, 3.0, 17.0, 0.018518518518518517, 0.0, 5.0, 54.0, 38.0, 7.0, 7.714285714285714, 1.4210526315789473, 0.7037037037037037, 16.0, 38.0, 0.2962962962962963, 15.959183673469388, 30.0, 3.0, 3.1481481481481484, 12.11111111111111, 7.0, 5.0, 209.0, 269.0, 0.4372384937238494, 269.0, 478.0, 0.5627615062761506, 12.713302752293577, 0.0, 30.0, 6631.0, 15901.0, 0.417017797622791, 0.14574370865335962))","Map(vectorType -> dense, length -> 2, values -> List(18.61067777789579, 1.3893222221042096))","Map(vectorType -> dense, length -> 2, values -> List(0.9305338888947896, 0.06946611110521048))",0.0,"List(0.9305338888947896, 0.06946611110521048)",0.0694661111052104


In [0]:
REORDER_THRESHOLD = 0.21
# 1_proba값이 REORDER_THRESHOLD보다 크면 1, 그렇지 않으면 0으로 reordered 컬럼 추가.
predictions = predictions.withColumn('reordered', (F.col('1_proba') > REORDER_THRESHOLD).cast('int')) 
display(predictions.limit(10))

user_id,product_id,order_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction,probability_arr,1_proba,reordered
134,36431,831748,3,2,1,0.6666666666666666,4,2,6.0,14.333333333333334,21.0,7.0,4.0,10.666666666666666,0.0638297872340425,0.0952380952380952,2,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,2178,1792,0.5486146095717884,1792,3970,0.4513853904282116,12.340677499311484,0.0,30.0,73840,305655,0.2415795586527294,0.2098058317754821,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 4.0, 2.0, 6.0, 14.333333333333334, 21.0, 7.0, 4.0, 10.666666666666666, 0.06382978723404255, 0.09523809523809523, 2.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 2178.0, 1792.0, 0.5486146095717884, 1792.0, 3970.0, 0.4513853904282116, 12.340677499311484, 0.0, 30.0, 73840.0, 305655.0, 0.2415795586527294, 0.20980583177548218))","Map(vectorType -> dense, length -> 2, values -> List(16.430875460363215, 3.569124539636784))","Map(vectorType -> dense, length -> 2, values -> List(0.8215437730181607, 0.17845622698183922))",0.0,"List(0.8215437730181607, 0.17845622698183922)",0.1784562269818392,0
134,16953,831748,1,0,1,0.0,3,3,6.0,7.0,7.0,7.0,6.0,12.0,0.0212765957446808,0.0,3,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,10503,9530,0.5242849298657215,9530,20033,0.4757150701342784,12.2864745726266,0.0,30.0,81973,289400,0.2832515549412577,0.1924635151930206,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 6.0, 7.0, 7.0, 7.0, 6.0, 12.0, 0.02127659574468085, 0.0, 3.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 10503.0, 9530.0, 0.5242849298657215, 9530.0, 20033.0, 0.47571507013427844, 12.286474572626599, 0.0, 30.0, 81973.0, 289400.0, 0.28325155494125775, 0.19246351519302068))","Map(vectorType -> dense, length -> 2, values -> List(18.745288157231194, 1.2547118427688049))","Map(vectorType -> dense, length -> 2, values -> List(0.9372644078615597, 0.06273559213844024))",0.0,"List(0.9372644078615597, 0.06273559213844024)",0.0627355921384402,0
153,14992,1658650,10,9,1,0.9,16,2,4.3,12.2,28.0,6.0,1.7,16.2,0.0374531835205992,0.0508474576271186,3,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,16942,12127,0.5828201864529223,12127,29069,0.4171798135470776,11.64868977804172,0.0,30.0,159213,3418021,0.0465804627882625,0.370599350758815,"Map(vectorType -> dense, length -> 44, values -> List(10.0, 9.0, 1.0, 0.9, 16.0, 2.0, 4.3, 12.2, 28.0, 6.0, 1.7, 16.2, 0.03745318352059925, 0.05084745762711865, 3.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 16942.0, 12127.0, 0.5828201864529223, 12127.0, 29069.0, 0.4171798135470776, 11.64868977804172, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.37059935075881506))","Map(vectorType -> dense, length -> 2, values -> List(15.340715909550475, 4.659284090449524))","Map(vectorType -> dense, length -> 2, values -> List(0.7670357954775238, 0.23296420452247618))",0.0,"List(0.7670357954775238, 0.23296420452247618)",0.2329642045224761,1
153,33653,1658650,1,0,1,0.0,10,10,6.0,7.0,7.0,7.0,1.0,12.0,0.0037453183520599,0.0,9,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,87,52,0.6258992805755396,52,139,0.3741007194244604,9.7109375,0.0,30.0,73124,266637,0.2742455098129667,0.0998552096114937,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 10.0, 10.0, 6.0, 7.0, 7.0, 7.0, 1.0, 12.0, 0.003745318352059925, 0.0, 9.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 87.0, 52.0, 0.6258992805755396, 52.0, 139.0, 0.37410071942446044, 9.7109375, 0.0, 30.0, 73124.0, 266637.0, 0.2742455098129667, 0.09985520961149374))","Map(vectorType -> dense, length -> 2, values -> List(18.853207545862247, 1.1467924541377548))","Map(vectorType -> dense, length -> 2, values -> List(0.9426603772931121, 0.05733962270688773))",0.0,"List(0.9426603772931121, 0.05733962270688773)",0.0573396227068877,0
153,21903,1658650,7,6,1,0.8571428571428571,18,9,5.285714285714286,12.857142857142858,28.0,6.0,1.0,16.0,0.0262172284644194,0.0338983050847457,1,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,186884,55037,0.7725001136734719,55037,241921,0.2274998863265281,11.199653703303918,0.0,30.0,159418,1765313,0.0903057984618025,0.1371940878647256,"Map(vectorType -> dense, length -> 44, values -> List(7.0, 6.0, 1.0, 0.8571428571428571, 18.0, 9.0, 5.285714285714286, 12.857142857142858, 28.0, 6.0, 1.0, 16.0, 0.026217228464419477, 0.03389830508474576, 1.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 186884.0, 55037.0, 0.7725001136734719, 55037.0, 241921.0, 0.2274998863265281, 11.199653703303918, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.1371940878647256))","Map(vectorType -> dense, length -> 2, values -> List(14.838104547828625, 5.161895452171375))","Map(vectorType -> dense, length -> 2, values -> List(0.7419052273914313, 0.2580947726085688))",0.0,"List(0.7419052273914313, 0.2580947726085688)",0.2580947726085688,1
153,38159,1658650,4,3,1,0.75,19,1,5.75,9.333333333333334,12.0,7.0,1.25,14.75,0.0149812734082397,0.0169491525423728,0,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,12789,8512,0.6003943476832073,8512,21301,0.3996056523167926,9.67922380080194,0.0,30.0,177141,3642188,0.048635874919142,0.3509697773976506,"Map(vectorType -> dense, length -> 44, values -> List(4.0, 3.0, 1.0, 0.75, 19.0, 1.0, 5.75, 9.333333333333334, 12.0, 7.0, 1.25, 14.75, 0.0149812734082397, 0.01694915254237288, 0.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 12789.0, 8512.0, 0.6003943476832073, 8512.0, 21301.0, 0.3996056523167926, 9.67922380080194, 0.0, 30.0, 177141.0, 3642188.0, 0.04863587491914201, 0.3509697773976506))","Map(vectorType -> dense, length -> 2, values -> List(17.320692924288814, 2.679307075711183))","Map(vectorType -> dense, length -> 2, values -> List(0.8660346462144408, 0.13396535378555918))",0.0,"List(0.8660346462144408, 0.13396535378555918)",0.1339653537855591,0
180,30480,2769561,1,0,1,0.0,4,4,32.0,28.0,28.0,28.0,6.0,10.0,0.0045045045045045,0.0,2,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,3062,3151,0.492837598583615,3151,6213,0.507162401416385,11.187958383080334,0.0,30.0,76476,297037,0.2574628749953709,0.249699526421014,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 4.0, 4.0, 32.0, 28.0, 28.0, 28.0, 6.0, 10.0, 0.0045045045045045045, 0.0, 2.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 3062.0, 3151.0, 0.492837598583615, 3151.0, 6213.0, 0.507162401416385, 11.187958383080334, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.24969952642101406))","Map(vectorType -> dense, length -> 2, values -> List(18.61067777789579, 1.3893222221042096))","Map(vectorType -> dense, length -> 2, values -> List(0.9305338888947896, 0.06946611110521048))",0.0,"List(0.9305338888947896, 0.06946611110521048)",0.0694661111052104,0
180,22963,2769561,5,4,1,0.8,6,1,4.8,26.0,30.0,18.0,5.2,10.0,0.0225225225225225,0.0373831775700934,0,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,18893,9491,0.6656214768883878,9491,28384,0.3343785231116122,10.781790169848188,0.0,30.0,78030,395130,0.1974793106066358,0.1368992125049764,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 6.0, 1.0, 4.8, 26.0, 30.0, 18.0, 5.2, 10.0, 0.02252252252252252, 0.037383177570093455, 0.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 18893.0, 9491.0, 0.6656214768883878, 9491.0, 28384.0, 0.3343785231116122, 10.781790169848188, 0.0, 30.0, 78030.0, 395130.0, 0.1974793106066358, 0.13689921250497641))","Map(vectorType -> dense, length -> 2, values -> List(14.360743568272248, 5.6392564317277545))","Map(vectorType -> dense, length -> 2, values -> List(0.7180371784136123, 0.2819628215863877))",0.0,"List(0.7180371784136123, 0.2819628215863877)",0.2819628215863877,1
180,3376,2769561,4,3,1,0.75,4,1,29.25,28.666666666666668,30.0,28.0,5.0,10.25,0.018018018018018,0.02803738317757,2,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,6281,6457,0.4930915371329879,6457,12738,0.5069084628670121,11.447830101569714,0.0,30.0,76476,297037,0.2574628749953709,0.2494455878716411,"Map(vectorType -> dense, length -> 44, values -> List(4.0, 3.0, 1.0, 0.75, 4.0, 1.0, 29.25, 28.666666666666668, 30.0, 28.0, 5.0, 10.25, 0.018018018018018018, 0.028037383177570093, 2.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 6281.0, 6457.0, 0.4930915371329879, 6457.0, 12738.0, 0.5069084628670121, 11.447830101569714, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.24944558787164112))","Map(vectorType -> dense, length -> 2, values -> List(17.033499090494846, 2.9665009095051533))","Map(vectorType -> dense, length -> 2, values -> List(0.8516749545247423, 0.14832504547525766))",0.0,"List(0.8516749545247423, 0.14832504547525766)",0.1483250454752576,0
186,24787,470997,1,0,1,0.0,2,2,5.0,18.0,18.0,18.0,3.0,17.0,0.0185185185185185,0.0,5,54,38,7,7.714285714285714,1.4210526315789471,0.7037037037037037,16,38,0.2962962962962963,15.959183673469388,30.0,3.0,3.1481481481481484,12.11111111111111,7,5.0,209,269,0.4372384937238494,269,478,0.5627615062761506,12.713302752293576,0.0,30.0,6631,15901,0.417017797622791,0.1457437086533596,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 2.0, 2.0, 5.0, 18.0, 18.0, 18.0, 3.0, 17.0, 0.018518518518518517, 0.0, 5.0, 54.0, 38.0, 7.0, 7.714285714285714, 1.4210526315789473, 0.7037037037037037, 16.0, 38.0, 0.2962962962962963, 15.959183673469388, 30.0, 3.0, 3.1481481481481484, 12.11111111111111, 7.0, 5.0, 209.0, 269.0, 0.4372384937238494, 269.0, 478.0, 0.5627615062761506, 12.713302752293577, 0.0, 30.0, 6631.0, 15901.0, 0.417017797622791, 0.14574370865335962))","Map(vectorType -> dense, length -> 2, values -> List(18.61067777789579, 1.3893222221042096))","Map(vectorType -> dense, length -> 2, values -> List(0.9305338888947896, 0.06946611110521048))",0.0,"List(0.9305338888947896, 0.06946611110521048)",0.0694661111052104,0


In [0]:
# reordered가 1인 데이터만 추출. 
prediction_reordered = predictions.filter('reordered == 1')

In [0]:
#sample_submission.csv에 있는 데이터 로드하여, orders.csv에 있는 eval_set이 test이 데이터와 건수 비교. 
submission_sdf = spark.read.csv('/FileStore/tables/sample_submission.csv', header=True, inferSchema=True)
print(submission_sdf.count(), orders_sdf.filter("eval_set == 'test'").count())
display(submission_sdf)

75000 75000


order_id,products
17,39276 29259
34,39276 29259
137,39276 29259
182,39276 29259
257,39276 29259
313,39276 29259
353,39276 29259
386,39276 29259
414,39276 29259
418,39276 29259


In [0]:
test_orders_sdf = orders_sdf.filter("eval_set == 'test'")
display(test_orders_sdf.orderBy('order_id'))

order_id,user_id,eval_set,order_number,order_dow,order_hour_of_day,days_since_prior_order
17,36855,test,5,6,15,1.0
34,35220,test,20,3,11,8.0
137,187107,test,9,2,19,30.0
182,115892,test,28,0,11,8.0
257,35581,test,9,6,23,5.0
313,113359,test,31,6,22,7.0
353,173814,test,4,4,13,30.0
386,55492,test,8,0,15,30.0
414,120775,test,18,5,14,8.0
418,33565,test,12,0,12,14.0


In [0]:
submission_sdf.createOrReplaceTempView('submission')

In [0]:
%sql
-- test_data에서 submission에 없는 order가 있는지 확인
select count(*)
from test_data a
left outer join submission b
on a.order_id = b.order_id 
where b.order_id is null 

count(1)
0


In [0]:
display(predictions.limit(10))

user_id,product_id,order_id,up_cnt,up_reord_cnt,up_no_reord_cnt,up_reoredered_avg,up_max_ord_num,up_min_ord_num,up_avg_cart,up_avg_prior_days,up_max_prior_days,up_min_prior_days,up_avg_ord_dow,up_avg_ord_hour,up_usr_ratio,up_usr_reord_ratio,up_usr_ord_num_diff,usr_total_cnt,prd_uq_cnt,order_uq_cnt,usr_avg_prd_cnt,usr_avg_uq_prd_cnt,usr_uq_prd_ratio,usr_reord_cnt,usr_no_reord_cnt,usr_reordered_avg,usr_avg_prior_days,usr_max_prior_days,usr_min_prior_days,usr_avg_order_dow,usr_avg_order_hour_of_day,usr_max_order_number,days_since_prior_order,prd_reordered_cnt,prd_no_reordered_cnt,prd_avg_reordered,prd_unq_usr_cnt,prd_total_cnt,prd_usr_ratio,prd_avg_prior_days,prd_min_prior_days,prd_max_prior_days,aisle_distinct_usr_cnt,aisle_total_cnt,aisle_usr_ratio,usr_ratio_diff,features,rawPrediction,probability,prediction,probability_arr,1_proba,reordered
134,36431,831748,3,2,1,0.6666666666666666,4,2,6.0,14.333333333333334,21.0,7.0,4.0,10.666666666666666,0.0638297872340425,0.0952380952380952,2,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,2178,1792,0.5486146095717884,1792,3970,0.4513853904282116,12.340677499311484,0.0,30.0,73840,305655,0.2415795586527294,0.2098058317754821,"Map(vectorType -> dense, length -> 44, values -> List(3.0, 2.0, 1.0, 0.6666666666666666, 4.0, 2.0, 6.0, 14.333333333333334, 21.0, 7.0, 4.0, 10.666666666666666, 0.06382978723404255, 0.09523809523809523, 2.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 2178.0, 1792.0, 0.5486146095717884, 1792.0, 3970.0, 0.4513853904282116, 12.340677499311484, 0.0, 30.0, 73840.0, 305655.0, 0.2415795586527294, 0.20980583177548218))","Map(vectorType -> dense, length -> 2, values -> List(16.430875460363215, 3.569124539636784))","Map(vectorType -> dense, length -> 2, values -> List(0.8215437730181607, 0.17845622698183922))",0.0,"List(0.8215437730181607, 0.17845622698183922)",0.1784562269818392,0
134,16953,831748,1,0,1,0.0,3,3,6.0,7.0,7.0,7.0,6.0,12.0,0.0212765957446808,0.0,3,47,26,6,7.833333333333333,1.807692307692308,0.5531914893617021,21,26,0.4468085106382978,21.65853658536585,30.0,7.0,3.6595744680851054,8.191489361702128,6,30.0,10503,9530,0.5242849298657215,9530,20033,0.4757150701342784,12.2864745726266,0.0,30.0,81973,289400,0.2832515549412577,0.1924635151930206,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 3.0, 3.0, 6.0, 7.0, 7.0, 7.0, 6.0, 12.0, 0.02127659574468085, 0.0, 3.0, 47.0, 26.0, 6.0, 7.833333333333333, 1.8076923076923077, 0.5531914893617021, 21.0, 26.0, 0.44680851063829785, 21.658536585365855, 30.0, 7.0, 3.6595744680851063, 8.191489361702128, 6.0, 30.0, 10503.0, 9530.0, 0.5242849298657215, 9530.0, 20033.0, 0.47571507013427844, 12.286474572626599, 0.0, 30.0, 81973.0, 289400.0, 0.28325155494125775, 0.19246351519302068))","Map(vectorType -> dense, length -> 2, values -> List(18.745288157231194, 1.2547118427688049))","Map(vectorType -> dense, length -> 2, values -> List(0.9372644078615597, 0.06273559213844024))",0.0,"List(0.9372644078615597, 0.06273559213844024)",0.0627355921384402,0
153,14992,1658650,10,9,1,0.9,16,2,4.3,12.2,28.0,6.0,1.7,16.2,0.0374531835205992,0.0508474576271186,3,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,16942,12127,0.5828201864529223,12127,29069,0.4171798135470776,11.64868977804172,0.0,30.0,159213,3418021,0.0465804627882625,0.370599350758815,"Map(vectorType -> dense, length -> 44, values -> List(10.0, 9.0, 1.0, 0.9, 16.0, 2.0, 4.3, 12.2, 28.0, 6.0, 1.7, 16.2, 0.03745318352059925, 0.05084745762711865, 3.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 16942.0, 12127.0, 0.5828201864529223, 12127.0, 29069.0, 0.4171798135470776, 11.64868977804172, 0.0, 30.0, 159213.0, 3418021.0, 0.04658046278826256, 0.37059935075881506))","Map(vectorType -> dense, length -> 2, values -> List(15.340715909550475, 4.659284090449524))","Map(vectorType -> dense, length -> 2, values -> List(0.7670357954775238, 0.23296420452247618))",0.0,"List(0.7670357954775238, 0.23296420452247618)",0.2329642045224761,1
153,33653,1658650,1,0,1,0.0,10,10,6.0,7.0,7.0,7.0,1.0,12.0,0.0037453183520599,0.0,9,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,87,52,0.6258992805755396,52,139,0.3741007194244604,9.7109375,0.0,30.0,73124,266637,0.2742455098129667,0.0998552096114937,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 10.0, 10.0, 6.0, 7.0, 7.0, 7.0, 1.0, 12.0, 0.003745318352059925, 0.0, 9.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 87.0, 52.0, 0.6258992805755396, 52.0, 139.0, 0.37410071942446044, 9.7109375, 0.0, 30.0, 73124.0, 266637.0, 0.2742455098129667, 0.09985520961149374))","Map(vectorType -> dense, length -> 2, values -> List(18.853207545862247, 1.1467924541377548))","Map(vectorType -> dense, length -> 2, values -> List(0.9426603772931121, 0.05733962270688773))",0.0,"List(0.9426603772931121, 0.05733962270688773)",0.0573396227068877,0
153,21903,1658650,7,6,1,0.8571428571428571,18,9,5.285714285714286,12.857142857142858,28.0,6.0,1.0,16.0,0.0262172284644194,0.0338983050847457,1,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,186884,55037,0.7725001136734719,55037,241921,0.2274998863265281,11.199653703303918,0.0,30.0,159418,1765313,0.0903057984618025,0.1371940878647256,"Map(vectorType -> dense, length -> 44, values -> List(7.0, 6.0, 1.0, 0.8571428571428571, 18.0, 9.0, 5.285714285714286, 12.857142857142858, 28.0, 6.0, 1.0, 16.0, 0.026217228464419477, 0.03389830508474576, 1.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 186884.0, 55037.0, 0.7725001136734719, 55037.0, 241921.0, 0.2274998863265281, 11.199653703303918, 0.0, 30.0, 159418.0, 1765313.0, 0.09030579846180252, 0.1371940878647256))","Map(vectorType -> dense, length -> 2, values -> List(14.838104547828625, 5.161895452171375))","Map(vectorType -> dense, length -> 2, values -> List(0.7419052273914313, 0.2580947726085688))",0.0,"List(0.7419052273914313, 0.2580947726085688)",0.2580947726085688,1
153,38159,1658650,4,3,1,0.75,19,1,5.75,9.333333333333334,12.0,7.0,1.25,14.75,0.0149812734082397,0.0169491525423728,0,267,90,19,14.052631578947368,2.966666666666667,0.3370786516853932,177,90,0.6629213483146067,14.640151515151516,28.0,6.0,1.0337078651685394,16.123595505617978,19,7.0,12789,8512,0.6003943476832073,8512,21301,0.3996056523167926,9.67922380080194,0.0,30.0,177141,3642188,0.048635874919142,0.3509697773976506,"Map(vectorType -> dense, length -> 44, values -> List(4.0, 3.0, 1.0, 0.75, 19.0, 1.0, 5.75, 9.333333333333334, 12.0, 7.0, 1.25, 14.75, 0.0149812734082397, 0.01694915254237288, 0.0, 267.0, 90.0, 19.0, 14.052631578947368, 2.966666666666667, 0.33707865168539325, 177.0, 90.0, 0.6629213483146067, 14.640151515151516, 28.0, 6.0, 1.0337078651685394, 16.123595505617978, 19.0, 7.0, 12789.0, 8512.0, 0.6003943476832073, 8512.0, 21301.0, 0.3996056523167926, 9.67922380080194, 0.0, 30.0, 177141.0, 3642188.0, 0.04863587491914201, 0.3509697773976506))","Map(vectorType -> dense, length -> 2, values -> List(17.320692924288814, 2.679307075711183))","Map(vectorType -> dense, length -> 2, values -> List(0.8660346462144408, 0.13396535378555918))",0.0,"List(0.8660346462144408, 0.13396535378555918)",0.1339653537855591,0
180,30480,2769561,1,0,1,0.0,4,4,32.0,28.0,28.0,28.0,6.0,10.0,0.0045045045045045,0.0,2,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,3062,3151,0.492837598583615,3151,6213,0.507162401416385,11.187958383080334,0.0,30.0,76476,297037,0.2574628749953709,0.249699526421014,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 4.0, 4.0, 32.0, 28.0, 28.0, 28.0, 6.0, 10.0, 0.0045045045045045045, 0.0, 2.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 3062.0, 3151.0, 0.492837598583615, 3151.0, 6213.0, 0.507162401416385, 11.187958383080334, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.24969952642101406))","Map(vectorType -> dense, length -> 2, values -> List(18.61067777789579, 1.3893222221042096))","Map(vectorType -> dense, length -> 2, values -> List(0.9305338888947896, 0.06946611110521048))",0.0,"List(0.9305338888947896, 0.06946611110521048)",0.0694661111052104,0
180,22963,2769561,5,4,1,0.8,6,1,4.8,26.0,30.0,18.0,5.2,10.0,0.0225225225225225,0.0373831775700934,0,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,18893,9491,0.6656214768883878,9491,28384,0.3343785231116122,10.781790169848188,0.0,30.0,78030,395130,0.1974793106066358,0.1368992125049764,"Map(vectorType -> dense, length -> 44, values -> List(5.0, 4.0, 1.0, 0.8, 6.0, 1.0, 4.8, 26.0, 30.0, 18.0, 5.2, 10.0, 0.02252252252252252, 0.037383177570093455, 0.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 18893.0, 9491.0, 0.6656214768883878, 9491.0, 28384.0, 0.3343785231116122, 10.781790169848188, 0.0, 30.0, 78030.0, 395130.0, 0.1974793106066358, 0.13689921250497641))","Map(vectorType -> dense, length -> 2, values -> List(14.360743568272248, 5.6392564317277545))","Map(vectorType -> dense, length -> 2, values -> List(0.7180371784136123, 0.2819628215863877))",0.0,"List(0.7180371784136123, 0.2819628215863877)",0.2819628215863877,1
180,3376,2769561,4,3,1,0.75,4,1,29.25,28.666666666666668,30.0,28.0,5.0,10.25,0.018018018018018,0.02803738317757,2,222,115,6,37.0,1.930434782608696,0.5180180180180181,107,115,0.481981981981982,25.1264367816092,30.0,17.0,4.342342342342342,10.477477477477477,6,30.0,6281,6457,0.4930915371329879,6457,12738,0.5069084628670121,11.447830101569714,0.0,30.0,76476,297037,0.2574628749953709,0.2494455878716411,"Map(vectorType -> dense, length -> 44, values -> List(4.0, 3.0, 1.0, 0.75, 4.0, 1.0, 29.25, 28.666666666666668, 30.0, 28.0, 5.0, 10.25, 0.018018018018018018, 0.028037383177570093, 2.0, 222.0, 115.0, 6.0, 37.0, 1.9304347826086956, 0.5180180180180181, 107.0, 115.0, 0.481981981981982, 25.126436781609197, 30.0, 17.0, 4.342342342342342, 10.477477477477477, 6.0, 30.0, 6281.0, 6457.0, 0.4930915371329879, 6457.0, 12738.0, 0.5069084628670121, 11.447830101569714, 0.0, 30.0, 76476.0, 297037.0, 0.25746287499537096, 0.24944558787164112))","Map(vectorType -> dense, length -> 2, values -> List(17.033499090494846, 2.9665009095051533))","Map(vectorType -> dense, length -> 2, values -> List(0.8516749545247423, 0.14832504547525766))",0.0,"List(0.8516749545247423, 0.14832504547525766)",0.1483250454752576,0
186,24787,470997,1,0,1,0.0,2,2,5.0,18.0,18.0,18.0,3.0,17.0,0.0185185185185185,0.0,5,54,38,7,7.714285714285714,1.4210526315789471,0.7037037037037037,16,38,0.2962962962962963,15.959183673469388,30.0,3.0,3.1481481481481484,12.11111111111111,7,5.0,209,269,0.4372384937238494,269,478,0.5627615062761506,12.713302752293576,0.0,30.0,6631,15901,0.417017797622791,0.1457437086533596,"Map(vectorType -> dense, length -> 44, values -> List(1.0, 0.0, 1.0, 0.0, 2.0, 2.0, 5.0, 18.0, 18.0, 18.0, 3.0, 17.0, 0.018518518518518517, 0.0, 5.0, 54.0, 38.0, 7.0, 7.714285714285714, 1.4210526315789473, 0.7037037037037037, 16.0, 38.0, 0.2962962962962963, 15.959183673469388, 30.0, 3.0, 3.1481481481481484, 12.11111111111111, 7.0, 5.0, 209.0, 269.0, 0.4372384937238494, 269.0, 478.0, 0.5627615062761506, 12.713302752293577, 0.0, 30.0, 6631.0, 15901.0, 0.417017797622791, 0.14574370865335962))","Map(vectorType -> dense, length -> 2, values -> List(18.61067777789579, 1.3893222221042096))","Map(vectorType -> dense, length -> 2, values -> List(0.9305338888947896, 0.06946611110521048))",0.0,"List(0.9305338888947896, 0.06946611110521048)",0.0694661111052104,0


In [0]:
# predictions 결과를 order_id로 group by하여 개별 order별 예측 상품건수와 재 주문 상품 건수 계산. 
predictions_grp = predictions.groupby('order_id').agg(F.count('*').alias('total_cnt_by_order_id'), 
                                                      F.sum(F.col('reordered')).alias('reordered_cnt'))
print(predictions_grp.count(), predictions_grp.filter('reordered_cnt == 0').count())
display(predictions_grp.filter('reordered_cnt == 0').orderBy('order_id'))

75000 8982


order_id,total_cnt_by_order_id,reordered_cnt
353,12,0
474,21,0
513,16,0
1195,16,0
1564,20,0
1789,17,0
2297,10,0
3373,9,0
3519,15,0
4848,27,0


In [0]:
# collect_list()함수 결과 보기. 
import pyspark.sql.functions as F

display(predictions.filter('reordered == 1').groupBy('order_id').agg(F.collect_list('product_id')).limit(10))

order_id,collect_list(product_id)
34,"List(16083, 47029, 47766, 43504, 39180, 21137, 47792, 2596, 39475)"
137,"List(38689, 5134, 23794, 25890, 2326, 24852, 41787)"
182,"List(33000, 47672, 9337, 13629, 39275, 5479, 47209)"
386,"List(38281, 4920, 39180, 37935, 47766, 28985, 21479, 42265, 15872, 40759, 22124, 45066, 24852, 30450)"
497,"List(27275, 39947, 31964, 36316, 1831)"
604,"List(12099, 24852, 2962, 16797)"
758,List(19660)
887,"List(24852, 25647, 28204)"
1304,List(24852)
1802,"List(13076, 13176, 43295, 34969, 20114, 3896, 21137, 47209, 4920, 21709, 38313)"


In [0]:
# collect_list('product_id')로 입력되는 product_id list값을 ' '으로 결합된 문자열로 변환하는 함수 생성. 
def get_product_ids_str(product_id_group):
    #product_id_group은 collect_list('product_id')로 group by된 집합으로 product_id를 list로 가지고 있는 형태로 입력 됨. 
    product_ids_str = ''
    for product_id in product_id_group:
        product_ids_str += ' ' + str(product_id)
    
    return product_ids_str

In [0]:
from pyspark.sql.functions import udf,col
from pyspark.sql.types import StringType

# 일반 python용 UDF를 pyspark용 UDF로 변환. udf(lambda 입력변수: 일반 UDF, 해당 일반 UDF의 반환형)
udf_get_product_ids_str = udf(lambda x:get_product_ids_str(x), StringType() )

In [0]:
submission_01 = predictions.filter('reordered == 1').groupBy('order_id').agg(udf_get_product_ids_str(F.collect_list('product_id')).alias('products'))
display(submission_01.limit(10))

order_id,products
34,16083 47029 47766 43504 39180 21137 47792 2596 39475
137,38689 5134 23794 25890 2326 24852 41787
182,33000 47672 9337 13629 39275 5479 47209
386,38281 4920 39180 37935 47766 28985 21479 42265 15872 40759 22124 45066 24852 30450
497,27275 39947 31964 36316 1831
604,12099 24852 2962 16797
758,19660
887,24852 25647 28204
1304,24852
1802,13076 13176 43295 34969 20114 3896 21137 47209 4920 21709 38313


In [0]:
display(submission_01)

order_id,products
34,16083 47029 47766 43504 39180 21137 47792 2596 39475
137,38689 5134 23794 25890 2326 24852 41787
182,33000 47672 9337 13629 39275 5479 47209
386,38281 4920 39180 37935 47766 28985 21479 42265 15872 40759 22124 45066 24852 30450
497,27275 39947 31964 36316 1831
604,12099 24852 2962 16797
758,19660
887,24852 25647 28204
1304,24852
1802,13076 13176 43295 34969 20114 3896 21137 47209 4920 21709 38313


In [0]:
submission_02 = predictions_grp.filter('reordered_cnt == 0').withColumn('products', F.lit('None')).select('order_id', 'products')
display(submission_02.limit(10))

order_id,products
3380622,
488000,
2726972,
743061,
1873499,
490175,
1815351,
2625243,
2958747,
2912061,


In [0]:
submission = submission_01.union(submission_02)
print('submission count:', submission.count())
submission = submission.orderBy('order_id')

display(submission.limit(76000))

submission count: 75000


order_id,products
17,13107
34,16083 47029 47766 43504 39180 21137 47792 2596 39475
137,38689 5134 23794 25890 2326 24852 41787
182,33000 47672 9337 13629 39275 5479 47209
257,27966 24852 30233 13870 45013 4605 1025 27104 29837 49235
313,21903 13198 46906 12779 45007
353,
386,38281 4920 39180 37935 47766 28985 21479 42265 15872 40759 22124 45066 24852 30450
414,21230 21709 31730 20564 20392 27845 33320
418,30489 38694 47766 41950 40268


In [0]:
display(submission)

order_id,products
17,13107
34,16083 47029 47766 43504 39180 21137 47792 2596 39475
137,38689 5134 23794 25890 2326 24852 41787
182,33000 47672 9337 13629 39275 5479 47209
257,27966 24852 30233 13870 45013 4605 1025 27104 29837 49235
313,21903 13198 46906 12779 45007
353,
386,38281 4920 39180 37935 47766 28985 21479 42265 15872 40759 22124 45066 24852 30450
414,21230 21709 31730 20564 20392 27845 33320
418,30489 38694 47766 41950 40268


In [0]:
%fs
ls dbfs:/user/hive/warehouse/train_data

path,name,size
dbfs:/user/hive/warehouse/train_data/_delta_log/,_delta_log/,0
dbfs:/user/hive/warehouse/train_data/part-00000-1faada17-97d2-48d0-8c7e-ff76e4766d74-c000.snappy.parquet,part-00000-1faada17-97d2-48d0-8c7e-ff76e4766d74-c000.snappy.parquet,23658093
dbfs:/user/hive/warehouse/train_data/part-00001-4a8f890b-9125-4a9b-8772-b08c45fe114d-c000.snappy.parquet,part-00001-4a8f890b-9125-4a9b-8772-b08c45fe114d-c000.snappy.parquet,23675489
dbfs:/user/hive/warehouse/train_data/part-00002-f73ca159-4cac-4524-910d-df54ce0c3147-c000.snappy.parquet,part-00002-f73ca159-4cac-4524-910d-df54ce0c3147-c000.snappy.parquet,23520302
dbfs:/user/hive/warehouse/train_data/part-00003-231a2551-1ab4-40ea-86c2-2534b90b8759-c000.snappy.parquet,part-00003-231a2551-1ab4-40ea-86c2-2534b90b8759-c000.snappy.parquet,23668053
dbfs:/user/hive/warehouse/train_data/part-00004-d9dcc1ba-244f-44dd-a75a-dee6c4229fbe-c000.snappy.parquet,part-00004-d9dcc1ba-244f-44dd-a75a-dee6c4229fbe-c000.snappy.parquet,23551618
dbfs:/user/hive/warehouse/train_data/part-00005-c7eee720-0a4c-4074-a7a2-55a419b5358e-c000.snappy.parquet,part-00005-c7eee720-0a4c-4074-a7a2-55a419b5358e-c000.snappy.parquet,22797374
dbfs:/user/hive/warehouse/train_data/part-00006-1c3ea36c-d10a-4b90-bbaf-fc3a0bbdf1f1-c000.snappy.parquet,part-00006-1c3ea36c-d10a-4b90-bbaf-fc3a0bbdf1f1-c000.snappy.parquet,22855717
dbfs:/user/hive/warehouse/train_data/part-00007-094353b5-a751-44bc-a1be-eccea7291591-c000.snappy.parquet,part-00007-094353b5-a751-44bc-a1be-eccea7291591-c000.snappy.parquet,22432206
dbfs:/user/hive/warehouse/train_data/part-00008-2a500970-a1b9-4418-8d84-5c3a219fbcc7-c000.snappy.parquet,part-00008-2a500970-a1b9-4418-8d84-5c3a219fbcc7-c000.snappy.parquet,22363078


In [0]:
%sql
drop table if exists train_data;
CREATE TABLE train_data
USING parquet
OPTIONS (
       path "/user/hive/warehouse/train_data/" );

drop table if exists test_data;

CREATE TABLE test_data
USING parquet
OPTIONS (
       path "/user/hive/warehouse/test_data/" );    

In [0]:
spark.sql("set spark.databricks.delta.formatCheck.enabled=false")

Out[7]: DataFrame[key: string, value: string]