In [1]:
from typing import Dict, Tuple

import numpy as np
import sklearn
import sklearn.model_selection
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset, DataLoader
from tqdm.notebook import tqdm
import pandas as pd

from recommend.utils import PROJ_ROOT

In [2]:
ratings = pd.read_pickle(f"{PROJ_ROOT}/data/ratings.pkl")

In [3]:
ratings.shape

(2797751, 5)

In [4]:
def filter_ratings(ratings: pd.DataFrame, min_user_ratings: int, min_movie_ratings: int) -> pd.DataFrame:
    ratings = ratings.dropna()

    user_rating_count = ratings.groupby("username").size()
    worthy_users = set(user_rating_count[user_rating_count >= min_user_ratings].index)
    ratings = ratings.loc[ratings.username.isin(worthy_users)]

    movie_rating_count = ratings.groupby("movie_id").size()
    worthy_movies = set(movie_rating_count[movie_rating_count >= min_movie_ratings].index)
    ratings = ratings.loc[ratings.movie_id.isin(worthy_movies)]
    
    return ratings

In [5]:
ratings = filter_ratings(ratings, 10, 10)

In [6]:
ratings.shape

(2722362, 5)

In [7]:
ratings_train, ratings_test = sklearn.model_selection.train_test_split(
    ratings,
    test_size=0.1,
    shuffle=True,
    random_state=0,
)

ratings_train, ratings_valid = sklearn.model_selection.train_test_split(
    ratings_train,
    test_size=0.1,
    shuffle=True,
    random_state=0,
)

In [8]:
idx2movie = pd.Series(ratings.movie_id.unique())
movie2idx = pd.Series(idx2movie.index.values, index=idx2movie.values)
movie2idx.head()

230421-houbicky                       0
10789-prvni-liga                      1
235032-yes-man                        2
234754-chlapec-v-pruhovanem-pyzamu    3
301717-nejvetsi-showman               4
dtype: int64

In [9]:
idx2user = pd.Series(ratings.username.unique())
user2idx = pd.Series(idx2user.index.values, index=idx2user.values)
user2idx.head()


kinghome     0
SimonShot    1
blackend     2
LCH          3
knoxville    4
dtype: int64

In [10]:
random_gen = np.random.default_rng(seed=42)

num_features = 100

movies_shape = len(movie2idx), num_features
movies = torch.from_numpy(
    random_gen.standard_normal(movies_shape, dtype=np.float32).clip(-2, 2)
)

users_shape = len(user2idx), num_features
users = torch.from_numpy(
    random_gen.standard_normal(users_shape, dtype=np.float32).clip(-2, 2)
)




In [11]:
device = 'cuda'
movies = movies.to(device).requires_grad_()
users = users.to(device).requires_grad_()
None

In [12]:
ratings_train.head()

Unnamed: 0,username,movie_id,stars,date,comment
328863,pepo,2596-osm-a-pul,1.0,2018-04-24,"Súhlasím, že 8 1/2 je umelecký film. Je totiž..."
1601217,Melly.pro,224137-hvezdny-prach,5.0,2008-01-18,"Tak nejsem si úplně jistej, co o tom říct, pr..."
1369474,GOREGASM,223564-kazatel,2.0,2016-01-31,Béčkové vody zakalené do digitálního apokalyp...
499592,Ernie_13,308635-zpatky-ve-hre,3.0,2014-03-07,"No je to tých 70%....hrozná réžia, hrozný sce..."
2322162,Khumbac,71194-walker-texas-ranger,4.0,2012-07-14,"Z nostalgie, z recese, z úcty ke komedii dává..."


In [13]:
class RatingDataset(Dataset): 

    def __init__(self, df_ratings: pd.DataFrame, movie2idx: pd.Series, user2idx: pd.Series):
        self.df_ratings = df_ratings
        self.movie2idx = movie2idx
        self.user2idx = user2idx

    def __len__(self) -> int:
        return len(self.df_ratings)

    def __getitem__(self, idx: int) -> Tuple[int, int, float]:
        row = self.df_ratings.iloc[idx]
        idx_movie = self.movie2idx[row.movie_id]
        idx_user = self.user2idx[row.username]
        rating = (row.stars / 5.0).astype(np.float32)
        return idx_movie, idx_user, rating

In [14]:
ds_train = RatingDataset(ratings_train, movie2idx, user2idx)
ds_valid = RatingDataset(ratings_valid, movie2idx, user2idx)
ds_test = RatingDataset(ratings_test, movie2idx, user2idx)

ds_train_mini = Subset(
    ds_train,
    np.random.default_rng(0).choice(len(ds_train), len(ds_valid), replace=False)
)

In [15]:
len(ds_train), len(ds_valid), len(ds_test), len(ds_train_mini)

(2205112, 245013, 272237, 245013)

In [16]:
ds_train[0]

(1023, 2359, 0.2)

In [17]:
batch_size = 64

def collate_move(device):
    def collate(*params):
        movie_idx, user_idx, rating = torch.utils.data.default_collate(*params)
        return movie_idx.to(device), user_idx.to(device), rating.to(device)
    return collate


loader_train = DataLoader(ds_train, batch_size, collate_fn=collate_move(device))
loader_valid = DataLoader(ds_valid, batch_size, collate_fn=collate_move(device))
loader_test = DataLoader(ds_test, batch_size, collate_fn=collate_move(device))

loader_train_mini = DataLoader(ds_train_mini, batch_size, collate_fn=collate_move(device))

In [18]:
len(loader_train), len(loader_valid), len(loader_test), len(loader_train_mini)

(34455, 3829, 4254, 3829)

In [19]:
optim = torch.optim.Adam(
    params=[movies, users],
    lr=1e-2,
)

In [23]:
class Trainer:

    def __init__(self,
        users: torch.Tensor,
        movies: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ):
        self.users = users
        self.movies = movies
        self.metrics_train = []
        self.metrics_valid = []
        self.past_movies = []
        self.past_users = []
        self.optimizer = optimizer


    def evaluate(
        self,
        loader: DataLoader,
        tqdm_desc: str,
        tqdm_leave: bool,
    ) -> Dict[str, float]:

        metrics = {
            "mse": 0.0,
            "mae_stars": 0.0
        }

        with torch.no_grad():
            for movie_idx, user_idx, rating in tqdm(loader, desc=tqdm_desc, leave=tqdm_leave, position=1):
                m = self.movies[movie_idx]
                u = self.users[user_idx]
                pred = torch.sigmoid((m * u).sum(-1))
                metrics["mse"] += F.mse_loss(pred, rating).cpu().item()
                metrics["mae_stars"] += F.l1_loss(pred * 5, rating * 5).cpu().item()
        return {name: value / len(loader) for name, value in metrics.items()}

    
    def train(
        self,
        loader_train: DataLoader,
        loader_train_mini: DataLoader,
        loader_valid: DataLoader,
        patience: int = 5,
        max_epochs: int = 25,
        evaluate_every_n_steps: int = 12500,
        early_stopping_metric: str = "mae_stars",
    ):
        step = 0
        for epoch in range(max_epochs):
            for movie_idx, user_idx, rating in tqdm(loader_train, desc=f"epoch {epoch}", position=0):
                m = self.movies[movie_idx]
                u = self.users[user_idx]
                pred = torch.sigmoid((m * u).sum(-1))
                loss = F.mse_loss(pred, rating)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                if step % evaluate_every_n_steps == 0:
                    self.metrics_train.append(self.evaluate(loader_train_mini, "train evaluation", tqdm_leave=False))
                    self.metrics_valid.append(self.evaluate(loader_valid, "validation evaluation", tqdm_leave=False))
                    print("train metrics:     ", self.metrics_train[-1])
                    print("validation metrics:", self.metrics_valid[-1])
                    self.past_movies.append(movies.clone().detach().cpu())
                    self.past_users.append(users.clone().detach().cpu())
                    if len(self.metrics_train) >= patience:
                        self.past_movies.pop(0)
                        self.past_users.pop(0)
                        last_n = pd.DataFrame(self.metrics_train[-patience:])[early_stopping_metric]
                        if last_n.argmin() == 0:
                            # TODO revert back to best weights
                            return
                step += 1


In [24]:
trainer = Trainer(users, movies, optim)

In [25]:
trainer.train(loader_train, loader_train_mini, loader_valid)

epoch 0:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.31217285436542824, 'mae_stars': 2.318140492360814}
validation metrics: {'mse': 0.31177442685668605, 'mae_stars': 2.3166307150212364}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.28783912754737545, 'mae_stars': 2.181232829323074}
validation metrics: {'mse': 0.30579391327733063, 'mae_stars': 2.274985332375689}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.2076809807413239, 'mae_stars': 1.7699378756575188}
validation metrics: {'mse': 0.22513355742814056, 'mae_stars': 1.86108544331115}


epoch 1:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.17533279422619039, 'mae_stars': 1.6047005335199331}
validation metrics: {'mse': 0.1919462075644263, 'mae_stars': 1.6911091848402304}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.16515964421218168, 'mae_stars': 1.5527619720906307}
validation metrics: {'mse': 0.18087961601196054, 'mae_stars': 1.6338155698016594}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1589812959655516, 'mae_stars': 1.5214867753128247}
validation metrics: {'mse': 0.1741661970559814, 'mae_stars': 1.5995021568760606}


epoch 2:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1558427723670915, 'mae_stars': 1.5057864524297597}
validation metrics: {'mse': 0.17030467672848024, 'mae_stars': 1.5802035616595491}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.15387444389212004, 'mae_stars': 1.4957386429454926}
validation metrics: {'mse': 0.16825371385437915, 'mae_stars': 1.5692810291532873}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.15155141286359317, 'mae_stars': 1.484190226675294}
validation metrics: {'mse': 0.16597221756913888, 'mae_stars': 1.557782025974206}


epoch 3:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1506289272000208, 'mae_stars': 1.4793216449904922}
validation metrics: {'mse': 0.1657295861410609, 'mae_stars': 1.5565647389005202}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14967359298123614, 'mae_stars': 1.4746654259688425}
validation metrics: {'mse': 0.16521901104742256, 'mae_stars': 1.5538258013945705}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14872468364872302, 'mae_stars': 1.4696809217960352}
validation metrics: {'mse': 0.16387339902043996, 'mae_stars': 1.5471875996339348}


epoch 4:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14790128427125918, 'mae_stars': 1.4656234376889417}
validation metrics: {'mse': 0.16300952418778578, 'mae_stars': 1.5426557613247274}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14709597173567446, 'mae_stars': 1.4615804719594914}
validation metrics: {'mse': 0.1621114492268499, 'mae_stars': 1.538137581959254}


epoch 5:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14631391719049808, 'mae_stars': 1.457781948847757}
validation metrics: {'mse': 0.16111080891378057, 'mae_stars': 1.5331772393986027}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14595851835693321, 'mae_stars': 1.4560409666788923}
validation metrics: {'mse': 0.16074813163741258, 'mae_stars': 1.53123655747919}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14538384422943965, 'mae_stars': 1.4530473195500895}
validation metrics: {'mse': 0.16091681007158665, 'mae_stars': 1.5320382666793309}


epoch 6:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14488364901655365, 'mae_stars': 1.4504647131344768}
validation metrics: {'mse': 0.16025333324839847, 'mae_stars': 1.5285688827354749}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14451378809818868, 'mae_stars': 1.4488543759254102}
validation metrics: {'mse': 0.1603950688420372, 'mae_stars': 1.5295591828529813}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14374118084056348, 'mae_stars': 1.4447723047544638}
validation metrics: {'mse': 0.15971026940869115, 'mae_stars': 1.525840607206888}


epoch 7:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14337176310818386, 'mae_stars': 1.4431463942636116}
validation metrics: {'mse': 0.15967409271668068, 'mae_stars': 1.5259095786227417}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14309018379435834, 'mae_stars': 1.4415360141528994}
validation metrics: {'mse': 0.15912226909179544, 'mae_stars': 1.5230878723542538}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14265399666629064, 'mae_stars': 1.4393181927399026}
validation metrics: {'mse': 0.15869647691590133, 'mae_stars': 1.5207014189404089}


epoch 8:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14274084332189163, 'mae_stars': 1.4398974312849286}
validation metrics: {'mse': 0.15882714821639152, 'mae_stars': 1.5216052194088736}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1427196379489253, 'mae_stars': 1.4397028263071618}
validation metrics: {'mse': 0.15864260572542188, 'mae_stars': 1.5204837664078408}


epoch 9:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1424189915554148, 'mae_stars': 1.438287384393602}
validation metrics: {'mse': 0.15825187841604946, 'mae_stars': 1.5186327556836507}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14236009312508063, 'mae_stars': 1.4378867493531324}
validation metrics: {'mse': 0.1580830452534168, 'mae_stars': 1.5176972515995866}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14201847423178374, 'mae_stars': 1.4363601311718275}
validation metrics: {'mse': 0.1579849755026149, 'mae_stars': 1.5173416603449026}


epoch 10:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14174339805770306, 'mae_stars': 1.4347998622535383}
validation metrics: {'mse': 0.15813174850987716, 'mae_stars': 1.5180240636645237}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14172006992598402, 'mae_stars': 1.4346863300690111}
validation metrics: {'mse': 0.1576237165334798, 'mae_stars': 1.5154449865007935}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14158218227897, 'mae_stars': 1.4340572842701076}
validation metrics: {'mse': 0.1575882539803006, 'mae_stars': 1.5153546163325335}


epoch 11:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14161764579922637, 'mae_stars': 1.4343500540971943}
validation metrics: {'mse': 0.1577254984324487, 'mae_stars': 1.5160420394871559}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1413510885224871, 'mae_stars': 1.4328586597702901}
validation metrics: {'mse': 0.1576704309830531, 'mae_stars': 1.5157980249830312}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14120620957287025, 'mae_stars': 1.4322071564387453}
validation metrics: {'mse': 0.15731973552453235, 'mae_stars': 1.5139133765511301}


epoch 12:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14097671852694943, 'mae_stars': 1.430944591628319}
validation metrics: {'mse': 0.157144139933583, 'mae_stars': 1.51295463495262}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14092996982105035, 'mae_stars': 1.4308830258366207}
validation metrics: {'mse': 0.1570930261058856, 'mae_stars': 1.5127112838362053}


epoch 13:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1407275581136482, 'mae_stars': 1.4298348048705287}
validation metrics: {'mse': 0.15712386640370063, 'mae_stars': 1.512989656300668}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1407260887050448, 'mae_stars': 1.4297425312141614}
validation metrics: {'mse': 0.15664965148016005, 'mae_stars': 1.5105141028314324}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1403438879790397, 'mae_stars': 1.4277892328329826}
validation metrics: {'mse': 0.15652485119826365, 'mae_stars': 1.5097750834429908}


epoch 14:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1402822542876262, 'mae_stars': 1.4275980683150133}
validation metrics: {'mse': 0.15617172518289255, 'mae_stars': 1.5078617400368377}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1401793563212605, 'mae_stars': 1.4270840807445493}
validation metrics: {'mse': 0.1561014299453877, 'mae_stars': 1.507774469777105}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.14001477915442434, 'mae_stars': 1.4262537919033536}
validation metrics: {'mse': 0.15585060307712773, 'mae_stars': 1.5064595982560791}


epoch 15:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13995763791015298, 'mae_stars': 1.4259584717789155}
validation metrics: {'mse': 0.15592386650997095, 'mae_stars': 1.506734778750936}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13997243726750086, 'mae_stars': 1.4259617535680043}
validation metrics: {'mse': 0.15597518062772125, 'mae_stars': 1.5070847819884472}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13998009575218096, 'mae_stars': 1.4261238140049912}
validation metrics: {'mse': 0.15564634355957932, 'mae_stars': 1.5052571175116476}


epoch 16:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13965341856629707, 'mae_stars': 1.4244337170483332}
validation metrics: {'mse': 0.15547454345083075, 'mae_stars': 1.5044870357906952}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13958943528204681, 'mae_stars': 1.4241516376855512}
validation metrics: {'mse': 0.15563870681954228, 'mae_stars': 1.5051248304586742}


epoch 17:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1394254940200269, 'mae_stars': 1.4233101167545583}
validation metrics: {'mse': 0.15558281062391227, 'mae_stars': 1.5049288897079722}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13930274718546873, 'mae_stars': 1.4227134057515967}
validation metrics: {'mse': 0.1553977464560928, 'mae_stars': 1.50411100829977}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1391560108629273, 'mae_stars': 1.4219556349660685}
validation metrics: {'mse': 0.15489630984494132, 'mae_stars': 1.5015915053480442}


epoch 18:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13914407709149848, 'mae_stars': 1.4219253501714006}
validation metrics: {'mse': 0.15509890140535065, 'mae_stars': 1.50267508559826}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1389895395421272, 'mae_stars': 1.4210873159175443}
validation metrics: {'mse': 0.15485207794672454, 'mae_stars': 1.5014314825212902}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13874296981770304, 'mae_stars': 1.4198097543979196}
validation metrics: {'mse': 0.15467392293248422, 'mae_stars': 1.5005623908228463}


epoch 19:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13875757116940166, 'mae_stars': 1.4199848342234487}
validation metrics: {'mse': 0.15453293221399722, 'mae_stars': 1.499675622619017}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13876630458664602, 'mae_stars': 1.4200189663741818}
validation metrics: {'mse': 0.15435756658310676, 'mae_stars': 1.4989952761655565}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1386679122367955, 'mae_stars': 1.4195172386183132}
validation metrics: {'mse': 0.15449854465370297, 'mae_stars': 1.4997007661122321}


epoch 20:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13865190016854428, 'mae_stars': 1.4193910264912832}
validation metrics: {'mse': 0.15461904273803068, 'mae_stars': 1.5001758357101833}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1385024419255296, 'mae_stars': 1.418700791410878}
validation metrics: {'mse': 0.15460062455059997, 'mae_stars': 1.5002335739621422}


epoch 21:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13826603557720885, 'mae_stars': 1.4174896826100243}
validation metrics: {'mse': 0.15436329487511743, 'mae_stars': 1.4989571808585613}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.1382390458275924, 'mae_stars': 1.4173858926793743}
validation metrics: {'mse': 0.15447654232448932, 'mae_stars': 1.4994970257147138}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13823232285777623, 'mae_stars': 1.4173433857725009}
validation metrics: {'mse': 0.15452617712729466, 'mae_stars': 1.4996521854313463}


epoch 22:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13803239270741183, 'mae_stars': 1.4162919229924382}
validation metrics: {'mse': 0.15444946788556896, 'mae_stars': 1.4994309315578096}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13805549586039512, 'mae_stars': 1.416490027458682}
validation metrics: {'mse': 0.15430024042667187, 'mae_stars': 1.4985654474736878}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13804132189162877, 'mae_stars': 1.416449916272476}
validation metrics: {'mse': 0.15403641229527026, 'mae_stars': 1.497183043142933}


epoch 23:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13777997508084658, 'mae_stars': 1.4150549250565612}
validation metrics: {'mse': 0.15420797718632903, 'mae_stars': 1.498262297798928}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13773588655522828, 'mae_stars': 1.414882451179103}
validation metrics: {'mse': 0.15396355212777738, 'mae_stars': 1.4969376468042008}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13756721240382463, 'mae_stars': 1.414006270263923}
validation metrics: {'mse': 0.15398795840657079, 'mae_stars': 1.4971467034565809}


epoch 24:   0%|          | 0/34455 [00:00<?, ?it/s]

train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13752928195722142, 'mae_stars': 1.4138111741432604}
validation metrics: {'mse': 0.15384165537547242, 'mae_stars': 1.4963939515876221}


train evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

validation evaluation:   0%|          | 0/3829 [00:00<?, ?it/s]

train metrics:      {'mse': 0.13745612807554375, 'mae_stars': 1.4134965944178082}
validation metrics: {'mse': 0.15380797337679367, 'mae_stars': 1.496082435126839}
