- ref:https://www.kaggle.com/kaerunantoka/h-m-eda-w-pyspark
- ref: https://qiita.com/t-yotsu/items/4cabd1ae5406cfd7d741
- ref: https://qiita.com/taka4sato/items/4ab2cf9e941599f1c0ca

In [1]:
!pip install pyspark > /dev/null



In [2]:
import pyspark
from pyspark.sql import SparkSession, SQLContext, Row
from pyspark.sql.window import Window

import pyspark.sql.functions as F 
import warnings
warnings.filterwarnings('ignore')

spark = SparkSession.builder.appName('h-and-m-personalized-fashion-recommendations').getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/03/26 03:53:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
class CFG:
    TRANSACTION_PATH = '../input/h-and-m-personalized-fashion-recommendations/transactions_train.csv'
    ARTICLE_PATH = '../input/h-and-m-personalized-fashion-recommendations/articles.csv'
    CUSTOMER_PATH = '../input/h-and-m-personalized-fashion-recommendations/customers.csv'
    SAMPLE_PATH = '../input/h-and-m-personalized-fashion-recommendations/sample_submission.csv'
    IMAGE_PATH = '../input/h-and-m-personalized-fashion-recommendations/images'

# log data

In [4]:
train = spark.read.option('header','true').csv(CFG.TRANSACTION_PATH)

                                                                                

In [5]:
train.count()

                                                                                

31788324

In [6]:
train.printSchema()

root
 |-- t_dat: string (nullable = true)
 |-- customer_id: string (nullable = true)
 |-- article_id: string (nullable = true)
 |-- price: string (nullable = true)
 |-- sales_channel_id: string (nullable = true)



In [7]:
train.select('sales_channel_id').distinct().show(5)



+----------------+
|sales_channel_id|
+----------------+
|               1|
|               2|
+----------------+



                                                                                

In [8]:
train.groupby('sales_channel_id').count().sort("count", ascending=False).show(5)



+----------------+--------+
|sales_channel_id|   count|
+----------------+--------+
|               2|22379862|
|               1| 9408462|
+----------------+--------+



                                                                                

In [9]:
train.registerTempTable("train_table") #dataframeにsqlテーブル名を付与
sqlContext = SQLContext(spark) #sparksqlを利用するには、sparkcontextだけでなくSQLContextも必要

sqlContext.sql(" SELECT * FROM train_table where sales_channel_id == 1 ").show(5, False)

+----------+----------------------------------------------------------------+----------+--------------------+----------------+
|t_dat     |customer_id                                                     |article_id|price               |sales_channel_id|
+----------+----------------------------------------------------------------+----------+--------------------+----------------+
|2018-09-20|00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4c73235dccbbc132280|0688873012|0.03049152542372881 |1               |
|2018-09-20|00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4c73235dccbbc132280|0501323011|0.053372881355932204|1               |
|2018-09-20|00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4c73235dccbbc132280|0688873020|0.03049152542372881 |1               |
|2018-09-20|00083cda041544b2fbb0e0d2905ad17da7cf1007526fb4c73235dccbbc132280|0688873011|0.03049152542372881 |1               |
|2018-09-20|001127bffdda108579e6cb16080440e89bf1250a776c6e55f56e35e9ee029a8d|0397068015|0.033881355932203386|1 

In [10]:
last2month_user = train.filter(train.t_dat > "2020-07-01").select("customer_id").distinct()
last2month_user.count()

                                                                                

484944

In [11]:
past_user = train.filter(train.t_dat <= "2020-06-30").select("customer_id").distinct()
past_user = past_user.join(last2month_user, past_user.customer_id == last2month_user.customer_id, "leftanti")
past_user.count()

                                                                                

876561

In [12]:
past_user.show()



+--------------------+
|         customer_id|
+--------------------+
|000064249685c1155...|
|0001ab2ebc1bb9a21...|
|000226b9ea8101924...|
|000253f6914890557...|
|0002697f519fce0a4...|
|00028f80bf6da2c28...|
|0003bf22bd16afdff...|
|0003e56a4332b2503...|
|0003e9bbb9faf3937...|
|000406f7566f4c298...|
|00048f2f68760664d...|
|0004ec078693c417c...|
|00053308738e0d6bc...|
|00058bbdde20e5d34...|
|00059eadf61683b9b...|
|0005baed366933727...|
|0005c68366e795568...|
|0005f8253fe3d050a...|
|000608ab13228c9d4...|
|000775da04900bb53...|
+--------------------+
only showing top 20 rows



                                                                                

In [13]:
train_past_user = train.join(past_user, train.customer_id == past_user.customer_id, "inner").drop(past_user.customer_id)

In [14]:
train_past_user.show(5, False)

[Stage 37:>                                                         (0 + 1) / 1]

+----------+----------+--------------------+----------------+----------------------------------------------------------------+
|t_dat     |article_id|price               |sales_channel_id|customer_id                                                     |
+----------+----------+--------------------+----------------+----------------------------------------------------------------+
|2019-10-02|0738133005|0.016932203389830508|2               |000064249685c11552da43ef22a5030f35a147f723d5b02ddd9fd22452b1f5a6|
|2019-10-02|0680265002|0.042355932203389825|2               |000064249685c11552da43ef22a5030f35a147f723d5b02ddd9fd22452b1f5a6|
|2019-10-02|0740962001|0.042355932203389825|2               |000064249685c11552da43ef22a5030f35a147f723d5b02ddd9fd22452b1f5a6|
|2019-01-25|0682520002|0.024389830508474576|1               |0003e56a4332b2503e34640be92031ad48f1280ee6e3a7f6b7b94664383facdc|
|2018-09-29|0587026013|0.040661016949152536|1               |00048f2f68760664d2d0fa1e7fbfe083f05287f342484c29a1

                                                                                

In [15]:
train.groupBy("customer_id").count().sort("count", ascending=False).show(5)



+--------------------+-----+
|         customer_id|count|
+--------------------+-----+
|be1981ab818cf4ef6...| 1895|
|b4db5e5259234574e...| 1441|
|49beaacac0c7801c2...| 1364|
|a65f77281a528bf5c...| 1361|
|cd04ec2726dd58a8c...| 1237|
+--------------------+-----+
only showing top 5 rows



                                                                                

In [16]:
train.groupBy("t_dat").count().sort("count", ascending=False).show(5)



+----------+------+
|     t_dat| count|
+----------+------+
|2019-09-28|198622|
|2020-04-11|162799|
|2019-11-29|160875|
|2018-11-23|142018|
|2018-09-29|141700|
+----------+------+
only showing top 5 rows



                                                                                

In [17]:
# groupby + aggは使える
train.groupBy("customer_id").agg({"price": "sum"}).show(3)



+--------------------+-------------------+
|         customer_id|         sum(price)|
+--------------------+-------------------+
|05f65801b9a2d28a5...| 1.7262881355932203|
|05f79b715286a38a8...| 0.3919322033898305|
|072d11a8c0a1e6d0f...|0.38801694915254237|
+--------------------+-------------------+
only showing top 3 rows



                                                                                

In [18]:
agged_df = train.groupBy("customer_id", "t_dat").count()
agged_df.show(3)



+--------------------+----------+-----+
|         customer_id|     t_dat|count|
+--------------------+----------+-----+
|00b5bd5358a051556...|2018-09-20|    3|
|00be0a263381af381...|2018-09-20|    1|
|023b48de81f6af9de...|2018-09-20|    9|
+--------------------+----------+-----+
only showing top 3 rows



                                                                                

In [19]:
train.withColumn("real_price", train.price * 590).show(5, False)

+----------+----------------------------------------------------------------+----------+--------------------+----------------+------------------+
|t_dat     |customer_id                                                     |article_id|price               |sales_channel_id|real_price        |
+----------+----------------------------------------------------------------+----------+--------------------+----------------+------------------+
|2018-09-20|000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318|0663713001|0.050830508474576264|2               |29.989999999999995|
|2018-09-20|000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8ffe7ad4a1091e318|0541518023|0.03049152542372881 |2               |17.99             |
|2018-09-20|00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2|0505221004|0.01523728813559322 |2               |8.99              |
|2018-09-20|00007d2de826758b65a93dd24ce629ed66842531df6699338c5570910a014cc2|0685687003|0.016932203389830508|2              

In [20]:
train.columns

['t_dat', 'customer_id', 'article_id', 'price', 'sales_channel_id']

In [21]:
train.groupBy("customer_id").agg({"price": "sum"}).show(3)



+--------------------+-------------------+
|         customer_id|         sum(price)|
+--------------------+-------------------+
|05f65801b9a2d28a5...| 1.7262881355932203|
|05f79b715286a38a8...| 0.3919322033898305|
|072d11a8c0a1e6d0f...|0.38801694915254237|
+--------------------+-------------------+
only showing top 3 rows



                                                                                

# article

In [22]:
article = spark.read.option('header','true').csv(CFG.ARTICLE_PATH)

In [23]:
article.filter(article["department_name"] == "Jacket Casual").show(5)

+----------+------------+------------+---------------+-----------------+------------------+-----------------------+-------------------------+-----------------+-----------------+-------------------------+---------------------------+--------------------------+----------------------------+-------------+---------------+----------+----------+--------------+----------------+----------+--------------+----------------+------------------+--------------------+
|article_id|product_code|   prod_name|product_type_no|product_type_name|product_group_name|graphical_appearance_no|graphical_appearance_name|colour_group_code|colour_group_name|perceived_colour_value_id|perceived_colour_value_name|perceived_colour_master_id|perceived_colour_master_name|department_no|department_name|index_code|index_name|index_group_no|index_group_name|section_no|  section_name|garment_group_no|garment_group_name|         detail_desc|
+----------+------------+------------+---------------+-----------------+------------------

# customer

In [24]:
customer = spark.read.option('header','true').csv(CFG.CUSTOMER_PATH)

In [25]:
customer.show(5, False)

+----------------------------------------------------------------+----+------+------------------+----------------------+---+----------------------------------------------------------------+
|customer_id                                                     |FN  |Active|club_member_status|fashion_news_frequency|age|postal_code                                                     |
+----------------------------------------------------------------+----+------+------------------+----------------------+---+----------------------------------------------------------------+
|00000dbacae5abe5e23885899a1fa44253a17956c6d1c3d25f88aa139fdfc657|null|null  |ACTIVE            |NONE                  |49 |52043ee2162cf5aa7ee79974281641c6f11a68d276429a91f8ca0d4b6efa8100|
|0000423b00ade91418cceaf3b26c6af3dd342b51fd051eec9c12fb36984420fa|null|null  |ACTIVE            |NONE                  |25 |2973abc54daa8a5f8ccfe9362140c63247c5eee03f1d93f4c830291c32bc3057|
|000058a12d5b43e67d225668fa1f8d618c13dc232df0cad8f

In [26]:
customer.groupby("FN").count().sort("count", ascending=False).show(5)

[Stage 58:>                                                         (0 + 4) / 4]

+----+------+
|  FN| count|
+----+------+
|null|895050|
| 1.0|476930|
+----+------+



                                                                                

In [27]:
customer.groupby("age").count().sort("count", ascending=False).show(5)

[Stage 61:>                                                         (0 + 4) / 4]

+---+-----+
|age|count|
+---+-----+
| 21|67530|
| 24|56124|
| 20|55196|
| 25|54989|
| 23|54867|
+---+-----+
only showing top 5 rows



                                                                                

In [28]:
customer.select([F.count(F.when(F.isnan(c), c)).alias(c) for c in customer.columns]).show()



+-----------+---+------+------------------+----------------------+---+-----------+
|customer_id| FN|Active|club_member_status|fashion_news_frequency|age|postal_code|
+-----------+---+------+------------------+----------------------+---+-----------+
|          0|  0|     0|                 0|                     0|  0|          0|
+-----------+---+------+------------------+----------------------+---+-----------+



                                                                                

# purchase history