In [68]:
import copy
from typing import Dict, Tuple

import numpy as np
import pandas as pd
import sklearn
import sklearn.model_selection
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset, DataLoader
from tqdm.notebook import tqdm

from recommend.utils import PROJ_ROOT

In [8]:
ratings = pd.read_pickle(f"{PROJ_ROOT}/data/ratings.pkl")

In [9]:
ratings.shape

(2797751, 5)

In [10]:
def filter_ratings(ratings: pd.DataFrame, min_user_ratings: int, min_movie_ratings: int) -> pd.DataFrame:
    ratings = ratings.dropna()

    user_rating_count = ratings.groupby("username").size()
    worthy_users = set(user_rating_count[user_rating_count >= min_user_ratings].index)
    ratings = ratings.loc[ratings.username.isin(worthy_users)]

    movie_rating_count = ratings.groupby("movie_id").size()
    worthy_movies = set(movie_rating_count[movie_rating_count >= min_movie_ratings].index)
    ratings = ratings.loc[ratings.movie_id.isin(worthy_movies)]
    
    return ratings

In [11]:
ratings = filter_ratings(ratings, 10, 10)

In [12]:
ratings.shape

(2722362, 5)

In [13]:
ratings_train, ratings_test = sklearn.model_selection.train_test_split(
    ratings,
    test_size=0.1,
    shuffle=True,
    random_state=0,
)

ratings_train, ratings_valid = sklearn.model_selection.train_test_split(
    ratings_train,
    test_size=0.1,
    shuffle=True,
    random_state=0,
)

In [14]:
idx2movie = pd.Series(ratings.movie_id.unique())
movie2idx = pd.Series(idx2movie.index.values, index=idx2movie.values)
movie2idx.head()

230421-houbicky                       0
10789-prvni-liga                      1
235032-yes-man                        2
234754-chlapec-v-pruhovanem-pyzamu    3
301717-nejvetsi-showman               4
dtype: int64

In [15]:
idx2user = pd.Series(ratings.username.unique())
user2idx = pd.Series(idx2user.index.values, index=idx2user.values)
user2idx.head()


kinghome     0
SimonShot    1
blackend     2
LCH          3
knoxville    4
dtype: int64

In [80]:
class Model(torch.nn.Module):

    def __init__(self, num_features: int, num_movies: int, num_users: int, seed: int = 42):
        super().__init__()
        random_gen = np.random.default_rng(seed)

        self.movies_weights = torch.nn.Embedding.from_pretrained(torch.from_numpy(
            random_gen.standard_normal((num_movies, num_features), dtype=np.float32).clip(-2, 2)
        ))

        self.users_weights = torch.nn.Embedding.from_pretrained(torch.from_numpy(
            random_gen.standard_normal((num_users, num_features), dtype=np.float32).clip(-2, 2)
        ))

        self.movies_biases = torch.nn.Parameter(torch.from_numpy(
            random_gen.standard_normal((num_movies,), dtype=np.float32).clip(-2, 2)
        ))

        self.users_biases = torch.nn.Parameter(torch.from_numpy(
            random_gen.standard_normal((num_users,), dtype=np.float32).clip(-2, 2)
        ))
        

    def forward(self, movie_idx, user_idx) -> torch.Tensor:
        m_w = self.movies_weights(movie_idx)
        m_b = self.movies_biases[movie_idx]
        u_w = self.users_weights(user_idx)
        u_b = self.users_biases[user_idx]
        return torch.sigmoid((m_w * u_w).sum(-1) + m_b + u_b)

In [99]:
model = Model(
    num_features=100,
    num_movies=len(movie2idx),
    num_users=len(user2idx)
)

In [100]:
device = 'cuda'
model = model.to(device)

In [101]:
ratings_train.head()

Unnamed: 0,username,movie_id,stars,date,comment
328863,pepo,2596-osm-a-pul,1.0,2018-04-24,"Súhlasím, že 8 1/2 je umelecký film. Je totiž..."
1601217,Melly.pro,224137-hvezdny-prach,5.0,2008-01-18,"Tak nejsem si úplně jistej, co o tom říct, pr..."
1369474,GOREGASM,223564-kazatel,2.0,2016-01-31,Béčkové vody zakalené do digitálního apokalyp...
499592,Ernie_13,308635-zpatky-ve-hre,3.0,2014-03-07,"No je to tých 70%....hrozná réžia, hrozný sce..."
2322162,Khumbac,71194-walker-texas-ranger,4.0,2012-07-14,"Z nostalgie, z recese, z úcty ke komedii dává..."


In [102]:
class RatingDataset(Dataset): 

    def __init__(self, df_ratings: pd.DataFrame, movie2idx: pd.Series, user2idx: pd.Series):
        self.df_ratings = df_ratings
        self.movie2idx = movie2idx
        self.user2idx = user2idx

    def __len__(self) -> int:
        return len(self.df_ratings)

    def __getitem__(self, idx: int) -> Tuple[int, int, float]:
        row = self.df_ratings.iloc[idx]
        idx_movie = self.movie2idx[row.movie_id]
        idx_user = self.user2idx[row.username]
        rating = (row.stars / 5.0).astype(np.float32)
        return idx_movie, idx_user, rating

In [103]:
ds_train = RatingDataset(ratings_train, movie2idx, user2idx)
ds_valid = RatingDataset(ratings_valid, movie2idx, user2idx)
ds_test = RatingDataset(ratings_test, movie2idx, user2idx)

ds_train_mini = Subset(
    ds_train,
    np.random.default_rng(0).choice(len(ds_train), len(ds_valid), replace=False)
)

In [104]:
len(ds_train), len(ds_valid), len(ds_test), len(ds_train_mini)

(2205112, 245013, 272237, 245013)

In [105]:
ds_train[0]

(1023, 2359, 0.2)

In [106]:
batch_size = 64

def collate_move(device):
    def collate(*params):
        movie_idx, user_idx, rating = torch.utils.data.default_collate(*params)
        return movie_idx.to(device), user_idx.to(device), rating.to(device)
    return collate


loader_train = DataLoader(ds_train, batch_size, collate_fn=collate_move(device))
loader_valid = DataLoader(ds_valid, batch_size, collate_fn=collate_move(device))
loader_test = DataLoader(ds_test, batch_size, collate_fn=collate_move(device))

loader_train_mini = DataLoader(ds_train_mini, batch_size, collate_fn=collate_move(device))

In [107]:
len(loader_train), len(loader_valid), len(loader_test), len(loader_train_mini)

(34455, 3829, 4254, 3829)

In [108]:
optim = torch.optim.Adam(
    params=model.parameters(),
    lr=1e-1,
)

In [114]:
class Trainer:

    def __init__(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
    ):
        self.model = model
        self.metrics_train = []
        self.metrics_valid = []
        self.past_models = []
        self.optimizer = optimizer


    def evaluate(
        self,
        loader: DataLoader,
        tqdm_desc: str,
        tqdm_leave: bool,
    ) -> Dict[str, float]:

        metrics = {
            "mse": 0.0,
            "mae_stars": 0.0
        }

        with torch.no_grad():
            for movie_idx, user_idx, rating in tqdm(loader, desc=tqdm_desc, leave=tqdm_leave, position=1):
                pred = model(movie_idx, user_idx)
                metrics["mse"] += F.mse_loss(pred, rating).cpu().item()
                metrics["mae_stars"] += F.l1_loss(pred * 5, rating * 5).cpu().item()
        return {name: value / len(loader) for name, value in metrics.items()}

    
    def train(
        self,
        loader_train: DataLoader,
        loader_train_mini: DataLoader,
        loader_valid: DataLoader,
        patience: int = 5,
        max_epochs: int = 25,
        evaluate_every_n_steps: int = 12500,
        early_stopping_metric: str = "mae_stars",
    ):
        step = 0
        for epoch in range(max_epochs):
            for movie_idx, user_idx, rating in tqdm(loader_train, desc=f"epoch {epoch}", position=0):
                pred = model(movie_idx, user_idx)
                F.mse_loss(pred, rating).backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                if step % evaluate_every_n_steps == 0:
                    self.metrics_train.append(self.evaluate(loader_train_mini, "train evaluation", tqdm_leave=False))
                    self.metrics_valid.append(self.evaluate(loader_valid, "validation evaluation", tqdm_leave=False))
                    print("train metrics:     ", self.metrics_train[-1])
                    print("validation metrics:", self.metrics_valid[-1])
                    self.past_models.append(copy.deepcopy(model).cpu())
                    if len(self.metrics_train) >= patience:
                        self.past_models.pop(0)
                        last_n = pd.DataFrame(self.metrics_train[-patience:])[early_stopping_metric]
                        if last_n.argmin() == 0:
                            device = list(model.parameters())[0].device
                            self.model = self.past_models[0].to(device)
                            return self.model
                step += 1
        return self.model


In [110]:
trainer = Trainer(model, optim)

In [111]:
trainer.train(loader_train, loader_train_mini, loader_valid)

epoch 0:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.31293469379841116, 'mae_stars': 2.321519985339697}
validation metrics: {'mse': 0.31249594485270943, 'mae_stars': 2.3195831577035024}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1469001427631398, 'mae_stars': 1.4594384279597301}
validation metrics: {'mse': 0.14913660194262665, 'mae_stars': 1.472096759412456}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14004322591868415, 'mae_stars': 1.4235767961855998}
validation metrics: {'mse': 0.1433714321206649, 'mae_stars': 1.4420375617561032}


epoch 1:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1376422492392725, 'mae_stars': 1.4112411741022342}
validation metrics: {'mse': 0.14132866200150132, 'mae_stars': 1.4312973939024396}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13615418705952076, 'mae_stars': 1.4037754787076202}
validation metrics: {'mse': 0.1402191033202072, 'mae_stars': 1.4256003709958969}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13566375462511993, 'mae_stars': 1.4008644891177497}
validation metrics: {'mse': 0.14003643939516258, 'mae_stars': 1.4247534457807225}


epoch 2:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13486892812014342, 'mae_stars': 1.396647390450034}
validation metrics: {'mse': 0.1389391231444244, 'mae_stars': 1.4189455155200514}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13443621128532898, 'mae_stars': 1.3944637374424629}
validation metrics: {'mse': 0.13886985373648006, 'mae_stars': 1.4184206273045792}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1339803738210649, 'mae_stars': 1.3923197326538115}
validation metrics: {'mse': 0.13828166647402915, 'mae_stars': 1.4153577600206648}


epoch 3:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1338594939825304, 'mae_stars': 1.3918715425557584}
validation metrics: {'mse': 0.13812339570315896, 'mae_stars': 1.4145074111109583}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13335103058560488, 'mae_stars': 1.3891987165168111}
validation metrics: {'mse': 0.1379327409496722, 'mae_stars': 1.413561559869154}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13296757327510905, 'mae_stars': 1.3870875698120981}
validation metrics: {'mse': 0.13789059922243105, 'mae_stars': 1.4133691378161564}


epoch 4:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1328107949320035, 'mae_stars': 1.3860466380161665}
validation metrics: {'mse': 0.13743950381629863, 'mae_stars': 1.4110892740233694}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1326264979074069, 'mae_stars': 1.3855122486895572}
validation metrics: {'mse': 0.13713748444267712, 'mae_stars': 1.409738239948729}


epoch 5:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13241319688750672, 'mae_stars': 1.3841073304905098}
validation metrics: {'mse': 0.1370571072464607, 'mae_stars': 1.4094177310788685}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13236457747683322, 'mae_stars': 1.3839328849586503}
validation metrics: {'mse': 0.1369808052551438, 'mae_stars': 1.4087926828384898}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13258457713881808, 'mae_stars': 1.3851385639304745}
validation metrics: {'mse': 0.13755290748268856, 'mae_stars': 1.4118866499860718}


epoch 6:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13243107687021305, 'mae_stars': 1.384488394784068}
validation metrics: {'mse': 0.13728658580988623, 'mae_stars': 1.4101795919354325}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1326270870870229, 'mae_stars': 1.3853622141784774}
validation metrics: {'mse': 0.1374406270929932, 'mae_stars': 1.411262867873504}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13230670554445195, 'mae_stars': 1.3837584925557154}
validation metrics: {'mse': 0.13740053390181697, 'mae_stars': 1.4110282857382568}


epoch 7:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13210678530614225, 'mae_stars': 1.3826548968353978}
validation metrics: {'mse': 0.13700368796511958, 'mae_stars': 1.4087902513923953}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13206383492140897, 'mae_stars': 1.3826336347113455}
validation metrics: {'mse': 0.13691346000779792, 'mae_stars': 1.4086978708399467}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13199305048683785, 'mae_stars': 1.3824357317707563}
validation metrics: {'mse': 0.13647035330331614, 'mae_stars': 1.4064547978234558}


epoch 8:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13154346722945062, 'mae_stars': 1.380117439529422}
validation metrics: {'mse': 0.13631329930234934, 'mae_stars': 1.4055721253413511}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13169107487686435, 'mae_stars': 1.3807538291218076}
validation metrics: {'mse': 0.1365590615698581, 'mae_stars': 1.4067478406893177}


epoch 9:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13163217275042732, 'mae_stars': 1.3805716379593544}
validation metrics: {'mse': 0.13643643405355588, 'mae_stars': 1.4062886592706085}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13157790586776538, 'mae_stars': 1.3800329198317978}
validation metrics: {'mse': 0.1359961824100022, 'mae_stars': 1.4039188185694573}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13156927269964935, 'mae_stars': 1.3802196153090063}
validation metrics: {'mse': 0.13627865587911583, 'mae_stars': 1.4054418853164248}


epoch 10:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13150855987781518, 'mae_stars': 1.380005461751216}
validation metrics: {'mse': 0.13626010397234609, 'mae_stars': 1.4050617051610252}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13138822008670545, 'mae_stars': 1.37894989140322}
validation metrics: {'mse': 0.13600253849855817, 'mae_stars': 1.4036825210377515}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1315935746625311, 'mae_stars': 1.3802220279973056}
validation metrics: {'mse': 0.13608013028480825, 'mae_stars': 1.404575894725139}


epoch 11:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13119663598176923, 'mae_stars': 1.3783125269347614}
validation metrics: {'mse': 0.13615182967558329, 'mae_stars': 1.404764423049066}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1310456905248396, 'mae_stars': 1.3773348581451572}
validation metrics: {'mse': 0.13592685897416001, 'mae_stars': 1.403352489823864}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1315063276457301, 'mae_stars': 1.3797840281920135}
validation metrics: {'mse': 0.13611126986869043, 'mae_stars': 1.4047661539810263}


epoch 12:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13122184861030114, 'mae_stars': 1.3781003946889379}
validation metrics: {'mse': 0.13601310313969936, 'mae_stars': 1.4039233851351853}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1312541563222445, 'mae_stars': 1.3787118163174827}
validation metrics: {'mse': 0.13639763811178512, 'mae_stars': 1.4062514528184624}


epoch 13:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1314029836680348, 'mae_stars': 1.3792965188758683}
validation metrics: {'mse': 0.13657375377232925, 'mae_stars': 1.4070892150142418}


Model(
  (movies_weights): Embedding(8670, 100)
  (users_weights): Embedding(22222, 100)
)