In [53]:
import pandas as pd

from rs_datasets import MovieLens

from replay.splitters import (
    TimeSplitter,
    LastNSplitter,
    NewUsersSplitter
)
from replay.preprocessing.filters import (
    InteractionEntriesFilter
)

In [13]:
# Column names used as the query (user) and item columns of `ratings` by every
# splitter and filter in this notebook.
query_column = "user_id"
item_column = "item_id"

# MovieLens-1M via rs_datasets; only its ratings (interactions) table is used.
ml = MovieLens("1m")
ratings = ml.ratings

In [None]:
# Time-based splitter. `time_threshold` is a float, which replay interprets as
# the share of (most recent) interactions that goes to the holdout part.
# The same splitter is reused for the test split and the validation split.
time_splitter = TimeSplitter(
    time_threshold=0.1,
    query_column=query_column,
    item_column=item_column
)

# Keep only users with at least 2 training interactions (sequences longer
# than 1). With min_inter_per_user=1 every user passes, making the filter a no-op.
interaction_filter = InteractionEntriesFilter(
    min_inter_per_user=2,
    query_column=query_column,
    item_column=item_column
)

# Leave-one-out splitter: per user (`divide_column`), the last N=1 interaction
# becomes the target part and the remaining interactions form the input part.
loo_splitter = LastNSplitter(
    N=1,
    divide_column=query_column,
    query_column=query_column,
    item_column=item_column
)

In [None]:
# Test split: the most recent share of `ratings` (per time_splitter's
# time_threshold) becomes test_holdout; everything earlier is train_val.
train_val, test_holdout = time_splitter.split(ratings)

# Validation split: the same time-based rule applied to train_val carves its
# most recent share off as val_holdout, leaving the earlier part as train.
train, val_holdout = time_splitter.split(train_val)

In [45]:
# Remove short user sequences from train, as configured in `interaction_filter`.
train = interaction_filter.transform(train)

# Items present in the filtered training data; every other item is "cold".
# Computed once and reused for all three filters below.
warm_items = train[item_column].unique()

# Remove cold items from the holdouts so every evaluated item appears in train.
val_holdout = val_holdout[val_holdout[item_column].isin(warm_items)]
test_holdout = test_holdout[test_holdout[item_column].isin(warm_items)]

# Remove cold items from train+val as well (optional), using the same warm-item set.
train_val = train_val[train_val[item_column].isin(warm_items)]

In [None]:
# Split the validation period per user with the leave-one-out splitter:
# the last interaction goes to val_target, the rest to val_input.
val_input, val_target = loo_splitter.split(val_holdout)

# Model input = full training history + the non-target part of the validation period.
val_input = pd.concat([train, val_input], axis=0)

# Remove targets with no input (filter cold users).
val_target = val_target[val_target[query_column].isin(val_input[query_column].unique())]

# Cap the validation subset at MAX_VAL_USERS users. NewUsersSplitter's second
# output holds the most recently appearing users (per replay docs), so this
# subsample is deterministic rather than random.
MAX_VAL_USERS = 5000
n_val_users = val_target[query_column].nunique()
if n_val_users > MAX_VAL_USERS:
    new_users_splitter = NewUsersSplitter(
        test_size=MAX_VAL_USERS / n_val_users,
        query_column=query_column,
        item_column=item_column,
    )
    _, val_target = new_users_splitter.split(val_target)

In [65]:
# Split the test period per user with the leave-one-out splitter:
# the last interaction goes to test_target, the rest to test_input.
test_input, test_target = loo_splitter.split(test_holdout)

# Model input = everything before the test period (train + validation period,
# i.e. `train_val`) + the non-target part of the test period. Using only
# `train` here would drop each user's validation-period interactions from
# the history seen at test time.
test_input = pd.concat([train_val, test_input], axis=0)

# Remove targets with no input (filter cold users).
test_target = test_target[test_target[query_column].isin(test_input[query_column].unique())]

# Cap the test subset at MAX_TEST_USERS users. NewUsersSplitter's second
# output holds the most recently appearing users (per replay docs), so this
# subsample is deterministic rather than random.
MAX_TEST_USERS = 10000
n_test_users = test_target[query_column].nunique()
if n_test_users > MAX_TEST_USERS:
    new_users_splitter = NewUsersSplitter(
        test_size=MAX_TEST_USERS / n_test_users,
        query_column=query_column,
        item_column=item_column,
    )
    _, test_target = new_users_splitter.split(test_target)