In [1]:
# Load IPython's autoreload extension and use mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
from pyspark.sql.session import SparkSession
import os
import logging

# --- Spark configuration -----------------------------------------------------
# Driver memory and core count for the local Spark master; "*" asks Spark to
# use every available core.
spark_memory = "256g"
spark_cores = "*"
# Base directory for Spark scratch space, input data and saved models.
user_home = os.environ["HOME"]

spark = (
    SparkSession
    .builder
    .config("spark.driver.memory", spark_memory)
    .config("spark.local.dir", os.path.join(user_home, "tmp"))
    .master(f"local[{spark_cores}]")
    .enableHiveSupport()
    .getOrCreate()
)

# --- Logging -----------------------------------------------------------------
# Quiet the py4j bridge logger; keep the root logger at DEBUG so the library's
# progress messages are shown in cell output.
spark_logger = logging.getLogger("py4j")
spark_logger.setLevel(logging.WARNING)
logger = logging.getLogger()
formatter = logging.Formatter("%(asctime)s, %(name)s, %(levelname)s: %(message)s",
                              datefmt="%d-%b-%y %H:%M:%S")
# Re-running this cell must not stack a second StreamHandler on the root
# logger (that would print every message twice), so the handler is tagged with
# a name and reused when it is already attached.
_handler_name = "notebook_stream"
hdlr = next(
    (h for h in logger.handlers if getattr(h, "name", None) == _handler_name),
    None,
)
if hdlr is None:
    hdlr = logging.StreamHandler()
    hdlr.name = _handler_name
    logger.addHandler(hdlr)
hdlr.setFormatter(formatter)
logger.setLevel(logging.DEBUG)
# Last expression: display the session summary as the cell output.
spark

In [3]:
from sponge_bob_magic.data_preparator.data_preparator import DataPreparator

# Build the project's DataPreparator on top of the active Spark session; the
# next cell uses its transform_log() to read the raw ratings file.
data_preparator = DataPreparator(spark)

In [4]:
# MovieLens 20M
# Location of the raw ratings file under the user's home directory.
ratings_path = os.path.join(user_home, "data/ml-20m/ratings.csv")

# Pairs the column names used in this notebook (keys) with the header names
# found in ratings.csv (values).
ratings_columns = {
    "user_id": "userId",
    "item_id": "movieId",
    "relevance": "rating",
    "timestamp": "timestamp",
}

log = data_preparator.transform_log(
    ratings_path,
    format_type="csv",
    columns_names=ratings_columns,
    date_format=None,
    sep=",",
    header=True,
)

In [5]:
from pyspark.sql import functions as F

# Earliest and latest interaction timestamps in the log.
# Spark's aggregate functions are accessed through the `F` namespace so that
# Python's builtin min()/max() are not shadowed for the rest of the session.
log.agg(F.min("timestamp"), F.max("timestamp")).head()

Row(min(timestamp)=datetime.datetime(1995, 1, 9, 14, 46, 44), max(timestamp)=datetime.datetime(2015, 3, 31, 9, 40, 2))

In [6]:
%%time
from sponge_bob_magic.splitters.log_splitter import LogSplitRandomlySplitter

# Random split of the log; the positional arguments (spark, 0.2, None) are
# passed through unchanged. Cold users and items are dropped from the result.
splitter = LogSplitRandomlySplitter(spark, 0.2, None)
train, test_input, test = splitter.split(
    log,
    drop_cold_users=True,
    drop_cold_items=True,
)

# Row counts of the three parts, printed on one line.
print(train.count(), test_input.count(), test.count())

15998563 15998563 4000703
CPU times: user 4 ms, sys: 8 ms, total: 12 ms
Wall time: 21.8 s


In [7]:
from sponge_bob_magic.metrics.metrics import Metrics

# Single Metrics instance shared by the evaluation cells below, which call its
# hit_rate_at_k, recall_at_k, precision_at_k and ndcg_at_k methods.
metrics = Metrics()

In [8]:
%%time
from sponge_bob_magic.models.popular_recomennder import PopularRecommender

# Popularity baseline: fit on the train log and recommend the top 10 items for
# every user present in the test split.
popular_recomennder = PopularRecommender(spark, 0, 0)

test_users = test.select("user_id").distinct().cache()
test_items = test.select("item_id").distinct().cache()
popular_model_path = os.path.join(user_home, "models/popular.model")

recs = popular_recomennder.fit_predict(
    k=10,
    users=test_users,
    items=test_items,
    log=train.cache(),
    context=None,
    user_features=None,
    item_features=None,
    path=popular_model_path,
).cache()

# Report the four quality metrics at k=10, one value per line, in this order:
# hit rate, recall, precision, NDCG.
for metric_at_k in (
    metrics.hit_rate_at_k,
    metrics.recall_at_k,
    metrics.precision_at_k,
    metrics.ndcg_at_k,
):
    print(metric_at_k(recs, test.cache(), 10))

17-Dec-19 11:15:37, root, DEBUG: Проверка датафреймов
17-Dec-19 11:15:40, root, DEBUG: Предварительная стадия обучения (pre-fit)
17-Dec-19 11:15:42, root, DEBUG: Среднее количество items у каждого user: 116
17-Dec-19 11:15:44, root, DEBUG: Основная стадия обучения (fit)
17-Dec-19 11:15:44, root, DEBUG: Проверка датафреймов
17-Dec-19 11:16:00, root, DEBUG: Количество items после фильтрации: 126


0.6074002848487916
0.07542818630124039
0.15879669459727155
0.18316094716688788
CPU times: user 324 ms, sys: 60 ms, total: 384 ms
Wall time: 1min 25s


In [9]:
%%time
from sponge_bob_magic.models.knn_recommender import KNNRecommender

# KNN recommender (constructed with the value 30): fit on the train log and
# recommend the top 10 items for every user present in the test split.
knn_recommender = KNNRecommender(spark, 30)

test_users = test.select("user_id").distinct().cache()
test_items = test.select("item_id").distinct().cache()
knn_model_path = os.path.join(user_home, "models/knn.model")

recs = knn_recommender.fit_predict(
    k=10,
    users=test_users,
    items=test_items,
    log=train.cache(),
    context=None,
    user_features=None,
    item_features=None,
    path=knn_model_path,
).cache()

# Report the four quality metrics at k=10, one value per line, in this order:
# hit rate, recall, precision, NDCG.
for metric_at_k in (
    metrics.hit_rate_at_k,
    metrics.recall_at_k,
    metrics.precision_at_k,
    metrics.ndcg_at_k,
):
    print(metric_at_k(recs, test.cache(), 10))

17-Dec-19 11:17:02, root, DEBUG: Проверка датафреймов
17-Dec-19 11:17:03, root, DEBUG: Предварительная стадия обучения (pre-fit)
17-Dec-19 11:24:20, root, DEBUG: Основная стадия обучения (fit)
17-Dec-19 11:25:55, root, DEBUG: Проверка датафреймов


0.7713184739623624
0.15336453733420258
0.2702152271199182
0.31113177636939676
CPU times: user 212 ms, sys: 52 ms, total: 264 ms
Wall time: 10min 42s


In [10]:
# Number of distinct users in the test split.
test.select("user_id").distinct().count()

138319

In [13]:
%%time
from sponge_bob_magic.models.als_recommender import ALSRecommender

# ALS recommender with rank 30: fit on the train log and recommend the top 10
# items for every user present in the test split, using batch_size=20000.
als_recommender = ALSRecommender(spark, rank=30)

test_users = test.select("user_id").distinct().cache()
test_items = test.select("item_id").distinct().cache()
als_model_path = os.path.join(user_home, "models/als.model")

recs = als_recommender.fit_predict(
    k=10,
    users=test_users,
    items=test_items,
    log=train.cache(),
    context=None,
    user_features=None,
    item_features=None,
    path=als_model_path,
    batch_size=20000,
).cache()

# Report the four quality metrics at k=10, one value per line, in this order:
# hit rate, recall, precision, NDCG.
for metric_at_k in (
    metrics.hit_rate_at_k,
    metrics.recall_at_k,
    metrics.precision_at_k,
    metrics.ndcg_at_k,
):
    print(metric_at_k(recs, test.cache(), 10))

17-Dec-19 13:30:13, root, DEBUG: Проверка датафреймов
17-Dec-19 13:30:13, root, DEBUG: Предварительная стадия обучения (pre-fit)
17-Dec-19 13:30:16, root, DEBUG: Основная стадия обучения (fit)
17-Dec-19 13:30:16, root, DEBUG: Индексирование данных
17-Dec-19 13:30:16, root, DEBUG: Обучение модели
17-Dec-19 13:31:10, root, DEBUG: Проверка датафреймов
17-Dec-19 13:53:59, root, DEBUG: b''


0.6852565446540244
0.08908984724170102
0.12146270577433325
0.13526835142817267
CPU times: user 432 ms, sys: 136 ms, total: 568 ms
Wall time: 24min 7s


In [21]:
# Shut down the Spark session now that evaluation is finished.
# NOTE(review): DataFrames created above (train, test, recs, ...) presumably
# cannot be evaluated after this point — re-run the setup cell to continue.
spark.stop()

In [22]:
# Show the first 10 files written under the ALS model's recs.parquet directory.
# Done in Python instead of `!ls ... | head`, which prints
# "ls: write error: Broken pipe" once `head` closes the pipe early.
# Hidden entries (names starting with ".") are skipped, as plain `ls` does.
als_recs_dir = os.path.join(user_home, "models/als.model/recs.parquet")
for entry_name in sorted(
    name for name in os.listdir(als_recs_dir) if not name.startswith(".")
)[:10]:
    print(entry_name)

part-00000-134a30d0-2e4c-4413-96f3-d95015665c30-c000.snappy.parquet
part-00000-2a5687e3-46b8-4073-9461-8af2ea7e20d0-c000.snappy.parquet
part-00000-411d61af-5e15-4fb8-b5e3-82dee627978e-c000.snappy.parquet
part-00000-5123bc6f-2db8-462f-b206-5a1499d4143f-c000.snappy.parquet
part-00000-658f5169-d2b5-45c2-bc68-5424650c7b95-c000.snappy.parquet
part-00000-92c3a914-0b96-4e4e-857c-77156a4b0f8a-c000.snappy.parquet
part-00001-134a30d0-2e4c-4413-96f3-d95015665c30-c000.snappy.parquet
part-00001-2a5687e3-46b8-4073-9461-8af2ea7e20d0-c000.snappy.parquet
part-00001-411d61af-5e15-4fb8-b5e3-82dee627978e-c000.snappy.parquet
part-00001-5123bc6f-2db8-462f-b206-5a1499d4143f-c000.snappy.parquet
ls: write error: Broken pipe
