In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from sponge_bob_magic.session_handler import State

spark = State().session
spark

In [3]:
from sponge_bob_magic.datasets import MovieLens

data = MovieLens("1m")
data.info()

ratings


Unnamed: 0,user_id,item_id,relevance,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968



users


Unnamed: 0,user_id,gender,age,occupation,zip_code
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117



items


Unnamed: 0,item_id,title,genres
0,1,Toy Story (1995),Animation|Children's|Comedy
1,2,Jumanji (1995),Adventure|Children's|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance





In [4]:
from sponge_bob_magic.data_preparator import DataPreparator

log = DataPreparator().transform(
    data=data.ratings,
    columns_names={
        "user_id": "user_id",
        "item_id": "item_id",
        "relevance": "relevance",
        "timestamp": "timestamp"
    }
)

In [5]:
from sponge_bob_magic.splitters import UserSplitter

splitter = UserSplitter(
    drop_cold_items=True,
    drop_cold_users=True,
    item_test_size=1,
    user_test_size=1000,
    seed=1234,
    shuffle=True
)
train, test = splitter.split(log)
(
    train.count(), 
    test.count()
)

(999209, 1000)

In [6]:
from sponge_bob_magic.metrics import HitRate, NDCG
from sponge_bob_magic.experiment import Experiment

metrics = Experiment(
    test, {NDCG(): 10, 
           HitRate(): 10}
)



In [7]:
%%time
from sponge_bob_magic.models import PopRec

metrics.add_result(
    "PopRec",
    PopRec().fit_predict(
        log=train,
        k=10,
        users=test.select("user_id").distinct()
    )
)

CPU times: user 245 ms, sys: 72.7 ms, total: 317 ms
Wall time: 12min 44s


In [10]:
%%time
from sponge_bob_magic.models import RandomPop

for i in range(-9, 10):
    alpha = i / 10
    metrics.add_result(
        f"RandomPop(alpha={alpha})",
        RandomPop(alpha).fit_predict(
            log=train,
            k=10,
            users=test.select("user_id").distinct()
        )
    )

CPU times: user 10.4 s, sys: 2.48 s, total: 12.9 s
Wall time: 36min 38s


In [15]:
metrics.pandas_df.sort_values("HitRate@10", ascending=False)

Unnamed: 0,HitRate@10,NDCG@10
PopRec,0.08,0.041412
RandomPop(alpha=-0.9),0.017,0.004409
RandomPop(alpha=0.5),0.017,0.00548
RandomPop(alpha=-0.7),0.015,0.004907
RandomPop(alpha=-0.3),0.015,0.007387
RandomPop(alpha=0.8),0.013,0.003524
RandomPop(alpha=0.1),0.012,0.005707
RandomPop(alpha=0.7),0.012,0.006054
RandomPop(alpha=0.4),0.012,0.005891
RandomPop(alpha=0.9),0.012,0.006558
