# Домашнее задание № 7, Кривоногов Н.В.

In [1]:
# !pip install pyspark

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import pandas as pd
import numpy as np

# import matplotlib.pyplot as plt
# %matplotlib inline

from pyspark.ml.recommendation import ALS
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType
import pyspark.sql.functions as sf

# Для работы с матрицами
from scipy.sparse import csr_matrix

# Матричная факторизация
from implicit.als import AlternatingLeastSquares
from implicit.nearest_neighbours import bm25_weight, tfidf_weight

# Модель второго уровня
# from lightgbm import LGBMClassifier

import os, sys

module_path = os.path.abspath(os.path.join(os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

# Написанные нами функции
from metrics import precision_at_k, recall_at_k
from utils import prefilter_items
# from recommenders import MainRecommender

In [4]:
data = pd.read_csv('retail_train.csv')
item_features = pd.read_csv('product.csv')
user_features = pd.read_csv('hh_demographic.csv')

# column processing
item_features.columns = [col.lower() for col in item_features.columns]
user_features.columns = [col.lower() for col in user_features.columns]

item_features.rename(columns={'product_id': 'item_id'}, inplace=True)
user_features.rename(columns={'household_key': 'user_id'}, inplace=True)

# train test split
test_size_weeks = 3

data_train = data[data['week_no'] < data['week_no'].max() - test_size_weeks]
data_test = data[data['week_no'] >= data['week_no'].max() - test_size_weeks]

data_train.head(2)

Unnamed: 0,user_id,basket_id,day,item_id,quantity,sales_value,store_id,retail_disc,trans_time,week_no,coupon_disc,coupon_match_disc
0,2375,26984851472,1,1004906,1,1.39,364,-0.6,1631,1,0.0,0.0
1,2375,26984851472,1,1033142,1,0.82,364,0.0,1631,1,0.0,0.0


In [5]:
n_items_before = data_train['item_id'].nunique()

# data_train = prefilter_items(data_train, item_features)

n_items_after = data_train['item_id'].nunique()
print('Decreased # items from {} to {}'.format(n_items_before, n_items_after))

Decreased # items from 86865 to 86865


In [6]:
popularity = data_train.groupby('item_id')['quantity'].sum().reset_index()
popularity.rename(columns={'quantity': 'n_sold'}, inplace=True)

top_5000 = popularity.sort_values('n_sold', ascending=False).head(5000).item_id.tolist()

In [7]:
data_train.loc[~data_train['item_id'].isin(top_5000), 'item_id'] = 999999

In [8]:
user_item_matrix = pd.pivot_table(data_train, 
                                  index='user_id', columns='item_id', 
                                  values='quantity', # Можно пробовать другие варианты
                                  aggfunc='count', 
                                  fill_value=0
                                 )

user_item_matrix = user_item_matrix.astype(float) # необходимый тип матрицы для implicit

# переведем в формат sparse matrix
sparse_user_item = csr_matrix(user_item_matrix).tocsr()

user_item_matrix.head(2)

item_id,202291,397896,420647,480014,545926,707683,731106,818980,819063,819227,...,15778533,15831255,15926712,15926775,15926844,15926886,15927403,15927661,15927850,16809471
user_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [9]:
data_test = data_test[data_test['item_id'].isin(data_train['item_id'].unique())]

In [10]:
result = data_test.groupby('user_id')['item_id'].unique().reset_index()
result.columns=['user_id', 'actual']
result.head(2)

Unnamed: 0,user_id,actual
0,1,"[821867, 834484, 856942, 865456, 914190, 95804..."
1,3,"[851057, 872021, 878302, 879948, 909638, 91320..."


In [11]:
userids = user_item_matrix.index.values
itemids = user_item_matrix.columns.values

matrix_userids = np.arange(len(userids))
matrix_itemids = np.arange(len(itemids))

id_to_itemid = dict(zip(matrix_itemids, itemids))
id_to_userid = dict(zip(matrix_userids, userids))

itemid_to_id = dict(zip(itemids, matrix_itemids))
userid_to_id = dict(zip(userids, matrix_userids))

SparkSession

In [12]:
session = (
        SparkSession.builder.config('spark.driver.memory', '1g')
        .config('spark.sql.shuffle.partitions', '100')
        .config('spark.driver.bindAddress', '127.0.0.1')
        .config('spark.driver.host', 'localhost')
        .master('local[1]')
        .enableHiveSupport()
        .getOrCreate()
        )

In [13]:
session

In [14]:
# data_train['item_idx'] = data_train['item_id'].map(lambda x:itemid_to_id[x])
# data_train['user_idx'] = data_train['user_id'].map(lambda x:userid_to_id[x])

In [15]:
# !pip install pyarrow

In [16]:
%%time

spark_data_train = session.createDataFrame(data_train[['user_id', 'item_id', 'quantity']])

Wall time: 2min


In [17]:
spark_data_train = spark_data_train.withColumnRenamed('quantity', 'relevance')

In [18]:
# spark_data_train.show(10)

In [19]:
# ---------------------------------------------------------------------------
# Py4JJavaError                             Traceback (most recent call last)
# ~\AppData\Local\Temp\ipykernel_1552\3177425222.py in <module>
# ----> 1 spark_data_train.show(10)

# C:\Conda\lib\site-packages\pyspark\sql\dataframe.py in show(self, n, truncate, vertical)
#     604 
#     605         if isinstance(truncate, bool) and truncate:
# --> 606             print(self._jdf.showString(n, 20, vertical))
#     607         else:
#     608             try:

# C:\Conda\lib\site-packages\py4j\java_gateway.py in __call__(self, *args)
#    1319 
#    1320         answer = self.gateway_client.send_command(command)
# -> 1321         return_value = get_return_value(
#    1322             answer, self.gateway_client, self.target_id, self.name)
#    1323 

# C:\Conda\lib\site-packages\pyspark\sql\utils.py in deco(*a, **kw)
#     188     def deco(*a: Any, **kw: Any) -> Any:
#     189         try:
# --> 190             return f(*a, **kw)
#     191         except Py4JJavaError as e:
#     192             converted = convert_exception(e.java_exception)

# C:\Conda\lib\site-packages\py4j\protocol.py in get_return_value(answer, gateway_client, target_id, name)
#     324             value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
#     325             if answer[1] == REFERENCE_TYPE:
# --> 326                 raise Py4JJavaError(
#     327                     "An error occurred while calling {0}{1}{2}.\n".
#     328                     format(target_id, ".", name), value)

# Py4JJavaError: An error occurred while calling o58.showString.
# : org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0) (noutnik executor driver): org.apache.spark.SparkException: Python worker failed to connect back.
# 	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:189)
# 	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
# 	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
# 	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:157)
# 	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
# 	at org.apache.spark.scheduler.Task.run(Task.scala:136)
# 	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
# 	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
# 	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
# 	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
# 	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
# 	at java.base/java.lang.Thread.run(Thread.java:1623)
# Caused by: java.net.SocketTimeoutException: Accept timed out
# 	at java.base/sun.nio.ch.NioSocketImpl.timedAccept(NioSocketImpl.java:694)
# 	at java.base/sun.nio.ch.NioSocketImpl.accept(NioSocketImpl.java:738)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:690)
# 	at java.base/java.net.ServerSocket.platformImplAccept(ServerSocket.java:655)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:631)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:588)
# 	at java.base/java.net.ServerSocket.accept(ServerSocket.java:546)
# 	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:176)
# 	... 29 more

# Driver stacktrace:
# 	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
# 	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
# 	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
# 	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
# 	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
# 	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
# 	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
# 	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
# 	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
# 	at scala.Option.foreach(Option.scala:407)
# 	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
# 	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
# 	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
# 	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
# 	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
# 	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
# 	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2238)
# 	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2259)
# 	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2278)
# 	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:506)
# 	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:459)
# 	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:48)
# 	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3868)
# 	at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:2863)
# 	at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:3858)
# 	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:510)
# 	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3856)
# 	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:109)
# 	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:169)
# 	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:95)
# 	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:779)
# 	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
# 	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3856)
# 	at org.apache.spark.sql.Dataset.head(Dataset.scala:2863)
# 	at org.apache.spark.sql.Dataset.take(Dataset.scala:3084)
# 	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:288)
# 	at org.apache.spark.sql.Dataset.showString(Dataset.scala:327)
# 	at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:104)
# 	at java.base/java.lang.reflect.Method.invoke(Method.java:578)
# 	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
# 	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
# 	at py4j.Gateway.invoke(Gateway.java:282)
# 	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
# 	at py4j.commands.CallCommand.execute(CallCommand.java:79)
# 	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
# 	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
# 	at java.base/java.lang.Thread.run(Thread.java:1623)
# Caused by: org.apache.spark.SparkException: Python worker failed to connect back.
# 	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:189)
# 	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
# 	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
# 	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:157)
# 	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
# 	at org.apache.spark.scheduler.Task.run(Task.scala:136)
# 	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
# 	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
# 	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
# 	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
# 	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
# 	... 1 more
# Caused by: java.net.SocketTimeoutException: Accept timed out
# 	at java.base/sun.nio.ch.NioSocketImpl.timedAccept(NioSocketImpl.java:694)
# 	at java.base/sun.nio.ch.NioSocketImpl.accept(NioSocketImpl.java:738)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:690)
# 	at java.base/java.net.ServerSocket.platformImplAccept(ServerSocket.java:655)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:631)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:588)
# 	at java.base/java.net.ServerSocket.accept(ServerSocket.java:546)
# 	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:176)
# 	... 29 more



In [20]:
# model = ALS(
#             rank=30,
#             userCol='user_id',
#             itemCol='item_id',
#             ratingCol='relevance',
#             maxIter=10,
#             alpha=1.0,
#             regParam=0.1,
#             implicitPrefs=True,
#             seed=42,
#             coldStartStrategy='drop',
#         ).fit(spark_data_train)

In [21]:
# ---------------------------------------------------------------------------
# Py4JJavaError                             Traceback (most recent call last)
# ~\AppData\Local\Temp\ipykernel_1552\1727582100.py in <module>
# ----> 1 model = ALS(
#       2             rank=30,
#       3             userCol='user_id',
#       4             itemCol='item_id',
#       5             ratingCol='relevance',

# C:\Conda\lib\site-packages\pyspark\ml\base.py in fit(self, dataset, params)
#     203                 return self.copy(params)._fit(dataset)
#     204             else:
# --> 205                 return self._fit(dataset)
#     206         else:
#     207             raise TypeError(

# C:\Conda\lib\site-packages\pyspark\ml\wrapper.py in _fit(self, dataset)
#     381 
#     382     def _fit(self, dataset: DataFrame) -> JM:
# --> 383         java_model = self._fit_java(dataset)
#     384         model = self._create_model(java_model)
#     385         return self._copyValues(model)

# C:\Conda\lib\site-packages\pyspark\ml\wrapper.py in _fit_java(self, dataset)
#     378 
#     379         self._transfer_params_to_java()
# --> 380         return self._java_obj.fit(dataset._jdf)
#     381 
#     382     def _fit(self, dataset: DataFrame) -> JM:

# C:\Conda\lib\site-packages\py4j\java_gateway.py in __call__(self, *args)
#    1319 
#    1320         answer = self.gateway_client.send_command(command)
# -> 1321         return_value = get_return_value(
#    1322             answer, self.gateway_client, self.target_id, self.name)
#    1323 

# C:\Conda\lib\site-packages\pyspark\sql\utils.py in deco(*a, **kw)
#     188     def deco(*a: Any, **kw: Any) -> Any:
#     189         try:
# --> 190             return f(*a, **kw)
#     191         except Py4JJavaError as e:
#     192             converted = convert_exception(e.java_exception)

# C:\Conda\lib\site-packages\py4j\protocol.py in get_return_value(answer, gateway_client, target_id, name)
#     324             value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
#     325             if answer[1] == REFERENCE_TYPE:
# --> 326                 raise Py4JJavaError(
#     327                     "An error occurred while calling {0}{1}{2}.\n".
#     328                     format(target_id, ".", name), value)

# Py4JJavaError: An error occurred while calling o62.fit.
# : org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 1 times, most recent failure: Lost task 0.0 in stage 1.0 (TID 1) (noutnik executor driver): org.apache.spark.SparkException: Python worker failed to connect back.
# 	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:189)
# 	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
# 	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
# 	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:157)
# 	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.sql.execution.SQLExecutionRDD.$anonfun$compute$1(SQLExecutionRDD.scala:52)
# 	at org.apache.spark.sql.internal.SQLConf$.withExistingConf(SQLConf.scala:158)
# 	at org.apache.spark.sql.execution.SQLExecutionRDD.compute(SQLExecutionRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
# 	at org.apache.spark.scheduler.Task.run(Task.scala:136)
# 	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
# 	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
# 	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
# 	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
# 	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
# 	at java.base/java.lang.Thread.run(Thread.java:1623)
# Caused by: java.net.SocketTimeoutException: Accept timed out
# 	at java.base/sun.nio.ch.NioSocketImpl.timedAccept(NioSocketImpl.java:694)
# 	at java.base/sun.nio.ch.NioSocketImpl.accept(NioSocketImpl.java:738)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:690)
# 	at java.base/java.net.ServerSocket.platformImplAccept(ServerSocket.java:655)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:631)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:588)
# 	at java.base/java.net.ServerSocket.accept(ServerSocket.java:546)
# 	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:176)
# 	... 40 more

# Driver stacktrace:
# 	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
# 	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
# 	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
# 	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
# 	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
# 	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
# 	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
# 	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
# 	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
# 	at scala.Option.foreach(Option.scala:407)
# 	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
# 	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
# 	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
# 	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
# 	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
# 	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
# 	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2238)
# 	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2259)
# 	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2278)
# 	at org.apache.spark.rdd.RDD.$anonfun$take$1(RDD.scala:1470)
# 	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
# 	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
# 	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
# 	at org.apache.spark.rdd.RDD.take(RDD.scala:1443)
# 	at org.apache.spark.rdd.RDD.$anonfun$isEmpty$1(RDD.scala:1578)
# 	at scala.runtime.java8.JFunction0$mcZ$sp.apply(JFunction0$mcZ$sp.java:23)
# 	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
# 	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
# 	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
# 	at org.apache.spark.rdd.RDD.isEmpty(RDD.scala:1578)
# 	at org.apache.spark.ml.recommendation.ALS$.train(ALS.scala:960)
# 	at org.apache.spark.ml.recommendation.ALS.$anonfun$fit$1(ALS.scala:722)
# 	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
# 	at scala.util.Try$.apply(Try.scala:213)
# 	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
# 	at org.apache.spark.ml.recommendation.ALS.fit(ALS.scala:704)
# 	at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:104)
# 	at java.base/java.lang.reflect.Method.invoke(Method.java:578)
# 	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
# 	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
# 	at py4j.Gateway.invoke(Gateway.java:282)
# 	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
# 	at py4j.commands.CallCommand.execute(CallCommand.java:79)
# 	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
# 	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
# 	at java.base/java.lang.Thread.run(Thread.java:1623)
# Caused by: org.apache.spark.SparkException: Python worker failed to connect back.
# 	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:189)
# 	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
# 	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
# 	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:157)
# 	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:65)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.sql.execution.SQLExecutionRDD.$anonfun$compute$1(SQLExecutionRDD.scala:52)
# 	at org.apache.spark.sql.internal.SQLConf$.withExistingConf(SQLConf.scala:158)
# 	at org.apache.spark.sql.execution.SQLExecutionRDD.compute(SQLExecutionRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
# 	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
# 	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
# 	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
# 	at org.apache.spark.scheduler.Task.run(Task.scala:136)
# 	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
# 	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
# 	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
# 	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
# 	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
# 	... 1 more
# Caused by: java.net.SocketTimeoutException: Accept timed out
# 	at java.base/sun.nio.ch.NioSocketImpl.timedAccept(NioSocketImpl.java:694)
# 	at java.base/sun.nio.ch.NioSocketImpl.accept(NioSocketImpl.java:738)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:690)
# 	at java.base/java.net.ServerSocket.platformImplAccept(ServerSocket.java:655)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:631)
# 	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:588)
# 	at java.base/java.net.ServerSocket.accept(ServerSocket.java:546)
# 	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:176)
# 	... 40 more



In [22]:
# recs_als = model.recommendForAllUsers(5)

In [23]:
# recs_als.show()

In [24]:
item_features = pd.read_csv('product.csv')
item_features.columns = [col.lower() for col in item_features.columns]
item_features.rename(columns={'product_id': 'item_id'}, inplace=True)

item_features.head(2)

Unnamed: 0,item_id,manufacturer,department,brand,commodity_desc,sub_commodity_desc,curr_size_of_product
0,25671,2,GROCERY,National,FRZN ICE,ICE - CRUSHED/CUBED,22 LB
1,26081,2,MISC. TRANS.,National,NO COMMODITY DESCRIPTION,NO SUBCOMMODITY DESCRIPTION,


In [25]:
item_features['department'].nunique()

44

In [26]:
item_features['commodity_desc'].nunique()

308

In [27]:
item_features['sub_commodity_desc'].nunique()

2383

In [28]:
# берем с запасом
recommendations = [ 26738, 26738, 26941, 25671, 26081, 26093, 18293696, 18294080, 18316298, 29247, 29252, 29340]

In [29]:
def postfilter(recommendations, item_info, N=5):
    """Пост-фильтрация товаров
    
    Input
    -----
    recommendations: list
        Ранжированный список item_id для рекомендаций
    item_info: pd.DataFrame
        Датафрейм с информацией о товарах
    """
    
    # Уникальность
#     recommendations = list(set(recommendations)) - неверно! так теряется порядок
    unique_recommendations = []
    [unique_recommendations.append(item) for item in recommendations if item not in unique_recommendations]
    
    # Разные категории
    categories_used = []
    final_recommendations = []
    
    CATEGORY_NAME = 'sub_commodity_desc'
    for item in unique_recommendations:
        category = item_features.loc[item_features['item_id'] == item, CATEGORY_NAME].values[0]
        
        if category not in categories_used:
            final_recommendations.append(item)
            
        unique_recommendations.remove(item)
        categories_used.append(category)
    
    n_rec = len(final_recommendations)
    if n_rec < N:
        final_recommendations.extend(unique_recommendations[:N - n_rec])
    else:
        final_recommendations = final_recommendations[:N]
    
    assert len(final_recommendations) == N, 'Количество рекомендаций != {}'.format(N)
    return final_recommendations

In [30]:
postfilter(recommendations, item_info=item_features, N=5)

[26738, 25671, 26093, 18294080, 29247]