In [496]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import TensorDataset

from torch.utils.data import Dataset, DataLoader

import pandas as pd
import numpy as np

from datetime import datetime
from tqdm import tqdm

import random
from pathlib import Path

from sklearn.model_selection import train_test_split
from itertools import chain

from sklearn.metrics import roc_auc_score
from tqdm.auto import tqdm, trange

In [497]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' # Tylko do debugging

print("CUDA available:", torch.cuda.is_available())
print("Number of GPUs:", torch.cuda.device_count())
print("Current device:", torch.cuda.current_device(), torch.cuda.get_device_name(0))

CUDA available: True
Number of GPUs: 1
Current device: 0 NVIDIA GeForce RTX 3060 Ti


In [498]:
BASE_DIR = Path(os.getcwd()).parent
DATA_DIR = BASE_DIR / "data"
df_users = pd.read_parquet(DATA_DIR / 'user_features_clean.parquet')
df_movies = pd.read_parquet(DATA_DIR / 'Movies_clean_Vec_v4_25keywords.parquet')
df_ratings = pd.read_parquet(DATA_DIR / 'ratings_groupped_ids.parquet')

In [499]:
df_movies.info()
df_ratings.info()
df_users.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 82918 entries, 0 to 82917
Data columns (total 29 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   movieId              82918 non-null  int64  
 1   runtime              82918 non-null  float64
 2   if_blockbuster       82918 non-null  int32  
 3   highly_watched       82918 non-null  int32  
 4   highly_rated         82918 non-null  int64  
 5   engagement_score     82918 non-null  float64
 6   cast_importance      82918 non-null  float64
 7   director_score       82918 non-null  float64
 8   has_keywords         82918 non-null  int64  
 9   has_cast             82918 non-null  int64  
 10  has_director         82918 non-null  int64  
 11  genre_ids            82918 non-null  object 
 12  decade_[1890, 1900)  82918 non-null  bool   
 13  decade_[1900, 1910)  82918 non-null  bool   
 14  decade_[1910, 1920)  82918 non-null  bool   
 15  decade_[1920, 1930)  82918 non-null 

In [500]:
# uids_users   = set(df_users['userId'])
# uids_ratings = set(df_ratings['userId'])
# assert uids_ratings <= uids_users, (
#     f"Brakujące userId w df_users: {uids_ratings - uids_users}"
# )
#
# mids_pos = set(x for lst in df_ratings['pos'] for x in lst)
# mids_neg = set(x for lst in df_ratings['neg'] for x in lst)
# mids_ratings = mids_pos | mids_neg
# mids_movies  = set(df_movies['movieId'])
# assert mids_ratings <= mids_movies, (
#     f"Brakujące movieId w df_movies (z ratingów): {mids_ratings - mids_movies}"
# )
#
# mids_users = set(x for lst in df_users['movies_seq'] for x in lst)
# assert mids_users <= mids_movies, (
#     f"Brakujące movieId w df_movies (z użytkowników): {mids_users - mids_movies}"
# )
#
# print("Sanity check: wszystkie userId i movieId się pokrywają")

In [501]:
mids_pos = set(x for lst in df_ratings['pos'] for x in lst)
mids_neg = set(x for lst in df_ratings['neg'] for x in lst)
missing = (mids_pos | mids_neg) - set(df_movies['movieId'])


In [502]:
import pandas as pd

pos_user_counts = {
    m: df_ratings['pos'].map(lambda lst: m in lst).sum()
    for m in missing
}
neg_user_counts = {
    m: df_ratings['neg'].map(lambda lst: m in lst).sum()
    for m in missing
}

df_missing_stats = (
    pd.DataFrame({
        'pos_users': pos_user_counts,
        'neg_users': neg_user_counts,
    })
    .sort_values(['pos_users','neg_users'], ascending=False)
)
print(df_missing_stats)


        pos_users  neg_users
198149          4          7
240896          1          2
164409          0          3
292229          0          2
292617          0          2
292597          0          2
235259          0          2
292737          0          1
277832          0          1
290385          0          1
292531          0          1
283477          0          1
292093          0          1
292063          0          1


In [503]:
valid_ids = set(df_movies['movieId'])
df_ratings['pos'] = df_ratings['pos'].apply(lambda lst: [m for m in lst if m in valid_ids])
df_ratings['neg'] = df_ratings['neg'].apply(lambda lst: [m for m in lst if m in valid_ids])

df_ratings = df_ratings[df_ratings['pos'].map(len).gt(0) & df_ratings['neg'].map(len).gt(0)]

# Przygotowanie movieId dla datasetów

In [504]:
'''
Sanity check do LOOCV
'''
single_pos_users = (df_ratings['pos'].apply(len) == 1).sum()

print(f"Liczba użytkowników z dokładnie jednym pozytywnym ratingiem: {single_pos_users}")

Liczba użytkowników z dokładnie jednym pozytywnym ratingiem: 262


In [505]:
'''
Do sprzatniecia userow (wyruczamy np. 262 uzytkownikow z 1 pos)
'''
pos_counts = df_ratings['pos'].apply(len)
neg_counts = df_ratings['neg'].apply(len)

mask = (pos_counts >= 2) & (neg_counts >= 2) # Wycinamy userow o wybranych parametrach
df_ratings = df_ratings.loc[mask]

common_users = set(df_ratings['userId'])
df_users   = df_users[df_users['userId'].isin(common_users)].reset_index(drop=True)
df_ratings = df_ratings[df_ratings['userId'].isin(common_users)].copy()

df_ratings = df_ratings.set_index('userId', drop=False)
df_movies  = df_movies.set_index('movieId', drop=False) # Bardzo wazne zeby wiedziec ze jest to 'movieId'

# Jezeli dodadtkowo usuwamy filmy ktore sie nie pojawiaja w ocenach (oszczedzamy pamiec)
used_movie_ids = set(df_users['movies_seq'].explode()) \
               | set(df_ratings['pos'].explode()) \
               | set(df_ratings['neg'].explode())
df_movies = df_movies.loc[df_movies.index.isin(used_movie_ids)].copy()

single_pos_users = (df_ratings['pos'].apply(len) == 1).sum()

print(f"Liczba użytkowników z dokładnie jednym pozytywnym ratingiem: {single_pos_users}")

Liczba użytkowników z dokładnie jednym pozytywnym ratingiem: 0


In [506]:
print(df_users.info())
print(df_ratings.info())
print(df_movies.info())

empty_pos_ratings = df_ratings['pos'].apply(lambda x: len(x) == 0).sum()
empty_neg_ratings = df_ratings['neg'].apply(lambda x: len(x) == 0).sum()

if empty_pos_ratings != 0 or empty_neg_ratings != 0:
    print(f'Empty ratings: pos: {empty_pos_ratings}, neg: {empty_neg_ratings}')
    raise Exception("Users without a single pos/neg rating exist in the ratings_groupped_ids dataset")

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 196886 entries, 0 to 196885
Data columns (total 29 columns):
 #   Column                   Non-Null Count   Dtype  
---  ------                   --------------   -----  
 0   userId                   196886 non-null  int64  
 1   num_rating               196886 non-null  float64
 2   avg_rating               196886 non-null  float64
 3   weekend_watcher          196886 non-null  float64
 4   genre_Action             196886 non-null  float64
 5   genre_Adventure          196886 non-null  float64
 6   genre_Animation          196886 non-null  float64
 7   genre_Comedy             196886 non-null  float64
 8   genre_Crime              196886 non-null  float64
 9   genre_Documentary        196886 non-null  float64
 10  genre_Drama              196886 non-null  float64
 11  genre_Family             196886 non-null  float64
 12  genre_Fantasy            196886 non-null  float64
 13  genre_History            196886 non-null  float64
 14  genr

In [507]:
unique_ids = set(
        df_users['movies_seq'].explode().tolist()
        + df_ratings['pos'].explode().tolist() 
        + df_ratings['neg'].explode().tolist()
    )

print('Unique movieIds:', len(unique_ids))
unique_ids = sorted(unique_ids)

movieId_to_idx = {id_: idx for idx, id_ in enumerate(unique_ids)}
print('min idx:', min(movieId_to_idx.values()))
print('max idx:', max(movieId_to_idx.values()))

n_items = len(unique_ids)

assert min(movieId_to_idx.values()) == 0
assert max(movieId_to_idx.values()) == n_items - 1

Unique movieIds: 82911
min idx: 0
max idx: 82910


In [508]:
# Zmapuj movieId do indeksów
df_users['movies_seq'] = df_users['movies_seq'].apply(lambda lst: [movieId_to_idx[m] for m in lst])
df_ratings['pos'] = df_ratings['pos'].apply(lambda lst: [movieId_to_idx[m] for m in lst])
df_ratings['neg'] = df_ratings['neg'].apply(lambda lst: [movieId_to_idx[m] for m in lst])
df_ratings = df_ratings.set_index('userId')

# df_movies musi być ograniczone tylko do używanych filmów
df_movies = df_movies[df_movies['movieId'].isin(movieId_to_idx)].copy()
df_movies['movieId'] = df_movies['movieId'].map(movieId_to_idx)
df_movies = df_movies.set_index('movieId')

# Final sanity check
assert df_users['movies_seq'].explode().max() < n_items
assert df_ratings['pos'].explode().max() < n_items
assert df_ratings['neg'].explode().max() < n_items

assert df_movies.index.max() < n_items
assert df_movies.index.notna().all()

# assert df_movies['movieId'].max() < n_items
# assert df_movies['movieId'].notna().all(), "Some movieIds weren't mapped!"

In [509]:
max_movie_idx = df_users['movies_seq'].explode().max()
print("max_movie_idx =", max_movie_idx)
print("n_items =", n_items)

assert max_movie_idx < n_items, "Indeks filmu przekracza rozmiar embeddingu"

max_movie_idx = 82910
n_items = 82911


In [510]:
def has_invalid_entries(seq_col):
    return seq_col.explode().isin([-1, np.nan, None]).any()

print("Zawiera niepoprawne wartości:", has_invalid_entries(df_users['movies_seq']))

Zawiera niepoprawne wartości: False


In [511]:
df_users.info()
df_users.head(100)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 196886 entries, 0 to 196885
Data columns (total 29 columns):
 #   Column                   Non-Null Count   Dtype  
---  ------                   --------------   -----  
 0   userId                   196886 non-null  int64  
 1   num_rating               196886 non-null  float64
 2   avg_rating               196886 non-null  float64
 3   weekend_watcher          196886 non-null  float64
 4   genre_Action             196886 non-null  float64
 5   genre_Adventure          196886 non-null  float64
 6   genre_Animation          196886 non-null  float64
 7   genre_Comedy             196886 non-null  float64
 8   genre_Crime              196886 non-null  float64
 9   genre_Documentary        196886 non-null  float64
 10  genre_Drama              196886 non-null  float64
 11  genre_Family             196886 non-null  float64
 12  genre_Fantasy            196886 non-null  float64
 13  genre_History            196886 non-null  float64
 14  genr

Unnamed: 0,userId,num_rating,avg_rating,weekend_watcher,genre_Action,genre_Adventure,genre_Animation,genre_Comedy,genre_Crime,genre_Documentary,...,genre_TV Movie,genre_Thriller,genre_War,genre_Western,type_of_viewer_negative,type_of_viewer_neutral,type_of_viewer_positive,movies_seq,ratings_seq,ts_seq
0,1,-0.068675,-0.347979,0.0,0.926736,-0.375240,-0.179705,-0.402570,0.892727,-1.076257,...,-0.284700,-0.022680,-0.426892,-0.911843,0.0,1.0,0.0,"[24, 1013, 1314, 1360, 1619, 303, 1027, 1190, ...","[-2.3984964034019467, 1.3836304001080941, -2.3...","[-1.2878777024141752, -1.2878663519376752, -1...."
1,2,-0.383417,1.210645,0.0,0.713096,0.940526,1.581734,0.973277,0.410751,0.788073,...,1.018523,1.084306,-1.207360,0.791019,0.0,0.0,1.0,"[30, 191, 273, 545, 234, 577, 503, 216, 376, 2...","[1.3836304001080941, -0.5074330016469263, 0.43...","[-1.709033992628413, -1.709033992628413, -1.70..."
2,3,-0.047456,-0.228499,0.0,0.045472,0.066743,0.383741,-0.669363,-0.692235,-0.211924,...,-0.184799,-0.650122,0.197482,0.094394,0.0,1.0,0.0,"[5218, 4768, 5679, 6196, 3893, 6222, 6391, 516...","[-0.5074330016469263, -0.5074330016469263, -0....","[-0.740135863595513, -0.7401358518778841, -0.7..."
3,4,-0.471828,-2.255334,0.0,-1.763184,-0.917026,-1.363040,-0.993442,-2.681933,-1.620281,...,-1.879511,-2.686105,-1.909780,-1.408750,1.0,0.0,0.0,"[2573, 2589, 2600, 2660, 220, 2612, 2770, 3091...","[-0.5074330016469263, -1.4529647025244365, -1....","[-1.2244648204630921, -1.2244648204630921, -1...."
4,5,-0.450609,-0.895880,0.0,0.178997,-0.375240,-0.002980,-0.993442,-1.908762,-0.675657,...,-0.742820,-0.372255,0.197482,-0.911843,0.0,1.0,0.0,"[228, 312, 159, 288, 314, 324, 429, 9, 183, 25...","[-1.4529647025244365, -0.5074330016469263, 0.4...","[-1.6920817200700495, -1.6920817200700495, -1...."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,97,-0.401099,0.439117,1.0,0.112234,0.564918,0.871345,0.478688,0.410751,0.251973,...,0.373421,0.162700,-0.153728,1.771455,0.0,1.0,0.0,"[228, 16, 5087, 3633, 3631, 2063, 1207, 3876, ...","[-0.5074330016469263, 0.43809869923058387, 0.4...","[-0.4013318737005198, -0.40133182683000435, -0..."
96,98,-0.464755,0.131258,0.0,1.514246,0.708332,0.434183,-0.518954,0.410751,0.038055,...,0.116008,0.559944,0.197482,0.105960,0.0,1.0,0.0,"[585, 5815, 5679, 5673, 2747, 2573, 2589, 2596...","[1.3836304001080941, 1.3836304001080941, 0.438...","[-0.9020311080297979, -0.9020306354187667, -0...."
97,99,-0.478901,1.106575,0.0,0.178997,0.301993,1.745670,1.790221,0.410751,0.715759,...,0.931507,0.809640,0.899903,0.724969,0.0,0.0,1.0,"[292, 2463, 2747, 731, 14840, 1166, 1057, 1194...","[-0.034667151208171196, 0.43809869923058387, -...","[0.6478568622869465, 0.6478568740045754, 0.647..."
98,100,0.313259,-0.381486,0.0,-0.544263,-0.397290,-0.574654,-0.042357,-0.095325,-0.318228,...,-0.312716,-0.055048,-0.613003,1.771455,0.0,1.0,0.0,"[3631, 1058, 1336, 1248, 1951, 1326, 15131, 20...","[-0.5074330016469263, -0.034667151208171196, 0...","[1.3519862967238052, 1.3519863514060733, 1.351..."


In [512]:
df_ratings.info()
df_ratings.head(100)

<class 'pandas.core.frame.DataFrame'>
Index: 196886 entries, 1 to 200948
Data columns (total 2 columns):
 #   Column  Non-Null Count   Dtype 
---  ------  --------------   ----- 
 0   pos     196886 non-null  object
 1   neg     196886 non-null  object
dtypes: object(2)
memory usage: 4.5+ MB


Unnamed: 0_level_0,pos,neg
userId,Unnamed: 1_level_1,Unnamed: 2_level_1
1,"[16, 29, 31, 79, 109, 164, 174, 229, 257, 298,...","[24, 28, 33, 35, 108, 159, 220, 340, 351, 522,..."
2,"[30, 33, 38, 47, 183, 184, 205, 214, 216, 219,...","[151, 191, 228, 250, 292, 301, 339, 344, 461, ..."
3,"[9, 10, 16, 25, 61, 108, 148, 149, 159, 257, 2...","[1, 47, 139, 151, 156, 166, 183, 206, 228, 324..."
4,"[220, 1232, 2011, 2660, 2731, 3063]","[1172, 1285, 1452, 1732, 2320, 2382, 2491, 249..."
5,"[9, 108, 159, 163, 344, 351, 359, 375, 429, 44...","[46, 148, 151, 183, 206, 228, 250, 285, 288, 2..."
...,...,...
97,"[16, 351, 359, 375, 475, 542, 582, 585, 835, 1...","[148, 205, 206, 228, 1241, 1465, 2063, 2165, 2..."
98,"[585, 878, 882, 894, 1065, 1157, 2016, 2247, 2...","[1, 1808, 2388, 2589, 2660, 2661, 2770, 2886]"
99,"[49, 109, 289, 314, 536, 585, 599, 731, 873, 1...","[292, 2747, 14840]"
100,"[0, 1, 46, 49, 257, 289, 292, 314, 352, 495, 5...","[4, 30, 33, 102, 148, 163, 351, 359, 372, 475,..."


In [513]:
df_movies.info()
df_movies.head(100)

<class 'pandas.core.frame.DataFrame'>
Index: 82903 entries, 14840 to 29524
Data columns (total 28 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   runtime              82903 non-null  float64
 1   if_blockbuster       82903 non-null  int32  
 2   highly_watched       82903 non-null  int32  
 3   highly_rated         82903 non-null  int64  
 4   engagement_score     82903 non-null  float64
 5   cast_importance      82903 non-null  float64
 6   director_score       82903 non-null  float64
 7   has_keywords         82903 non-null  int64  
 8   has_cast             82903 non-null  int64  
 9   has_director         82903 non-null  int64  
 10  genre_ids            82903 non-null  object 
 11  decade_[1890, 1900)  82903 non-null  bool   
 12  decade_[1900, 1910)  82903 non-null  bool   
 13  decade_[1910, 1920)  82903 non-null  bool   
 14  decade_[1920, 1930)  82903 non-null  bool   
 15  decade_[1930, 1940)  82903 non-null  

Unnamed: 0_level_0,runtime,if_blockbuster,highly_watched,highly_rated,engagement_score,cast_importance,director_score,has_keywords,has_cast,has_director,...,"decade_[1960, 1970)","decade_[1970, 1980)","decade_[1980, 1990)","decade_[1990, 2000)","decade_[2000, 2010)","decade_[2010, 2020)","decade_[2020, 2030)",text_embedded,actor_ids,director_ids
movieId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
14840,1.942703,1,1,1,4.591432,2.899632,2.653210,1,1,1,...,False,False,False,False,False,True,False,"[0.0008652941, 0.06077885, -0.07869467, -0.067...","[6454, 10631, 5457, 1952, 5950]",[797]
20922,2.432017,1,1,1,5.199338,2.789332,2.653210,1,1,1,...,False,False,False,False,False,True,False,"[-0.010866538, -0.01691181, -0.12693988, -0.04...","[659, 7298, 4974, 10576, 5292]",[797]
12164,2.033104,1,1,1,5.199338,3.099369,2.653210,1,1,1,...,False,False,False,False,True,False,False,"[-0.026262647, 0.055052526, -0.08173301, -0.01...","[1867, 3519, 7812, 1952, 4010]",[797]
14021,2.256745,1,1,1,4.123958,2.512635,2.304477,1,1,1,...,False,False,False,False,True,False,False,"[0.0031084684, -0.032840427, -0.12393689, -0.0...","[11434, 9935, 7629, 9574, 3709]",[2026]
16934,1.824556,1,1,1,5.199338,5.199338,1.817788,1,1,1,...,False,False,False,False,False,True,False,"[-0.015282603, 0.00047473708, -0.11172164, 0.0...","[9686, 1839, 1834, 9161, 4923]",[2496]
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1159,1.358913,1,1,1,3.003738,1.826226,0.704639,1,1,1,...,False,False,True,False,False,False,False,"[0.016825777, -0.02596536, -0.088541396, 0.058...","[3952, 693, 3373, 7093]",[1915]
4180,-0.276998,1,1,1,3.144047,2.695707,1.599158,1,1,1,...,False,False,False,False,True,False,False,"[0.05683445, -0.0026107691, -0.090642, 0.01653...","[2810, 1497, 7670, 10954, 5022]","[246, 4884]"
11313,0.908975,1,1,1,3.133301,1.566170,1.685018,1,1,1,...,False,False,False,False,True,False,False,"[-0.030460857, -0.0006409547, 0.00023257566, -...","[11136, 8598, 1274, 4226, 8408]",[544]
23316,1.497997,1,1,1,3.111676,2.686499,1.898700,1,1,1,...,False,False,False,False,False,True,False,"[-0.0011530533, 0.014149291, -0.10900124, -0.0...","[9597, 7471, 7114, 2054, 10367]",[3179]


In [514]:
'''
Do szybkich testow z mniejsza iloscia danych
'''
DEBUG = False

if DEBUG:
    df_users = df_users.sample(n=196886, random_state=213).copy()

    mask = df_ratings.index.isin(df_users['userId'])
    df_ratings = df_ratings[mask].copy()

    # used_movie_ids = set(df_users['movies_seq'].explode()) \
    #                | set(df_ratings['pos'].explode()) \
    #                | set(df_ratings['neg'].explode())
    # df_movies = df_movies[df_movies.index.isin(used_movie_ids)].copy()

movie_to_local = {mid: i for i, mid in enumerate(df_movies.index)}
local_to_movie = list(df_movies.index)

In [515]:
df_movies.info()
df_ratings.info()
df_users.info()

<class 'pandas.core.frame.DataFrame'>
Index: 82903 entries, 14840 to 29524
Data columns (total 28 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   runtime              82903 non-null  float64
 1   if_blockbuster       82903 non-null  int32  
 2   highly_watched       82903 non-null  int32  
 3   highly_rated         82903 non-null  int64  
 4   engagement_score     82903 non-null  float64
 5   cast_importance      82903 non-null  float64
 6   director_score       82903 non-null  float64
 7   has_keywords         82903 non-null  int64  
 8   has_cast             82903 non-null  int64  
 9   has_director         82903 non-null  int64  
 10  genre_ids            82903 non-null  object 
 11  decade_[1890, 1900)  82903 non-null  bool   
 12  decade_[1900, 1910)  82903 non-null  bool   
 13  decade_[1910, 1920)  82903 non-null  bool   
 14  decade_[1920, 1930)  82903 non-null  bool   
 15  decade_[1930, 1940)  82903 non-null  

# Przygotowanie danych do uczenia -> do gotowych batchy

In [516]:
# For padding 'global max len'

max_len_a = int(df_movies['actor_ids'].str.len().max())
max_len_d = int(df_movies['director_ids'].str.len().max())
max_len_g = int(df_movies['genre_ids'].str.len().max())

In [517]:
# For nn.Embeedings usage in Item Tower
all_actor_ids = list(chain.from_iterable(df_movies['actor_ids']))
num_actors = max(all_actor_ids) + 1

all_director_ids = list(chain.from_iterable(df_movies['director_ids']))
num_directors = max(all_director_ids) + 1

all_genre_ids = list(chain.from_iterable(df_movies['genre_ids']))
num_genres = max(all_genre_ids) + 1

print(num_actors, num_directors, num_genres)

11465 5163 20


In [518]:
def collect_user_features(u):
        """
        Zwraca cztery tensory: movies_seq, ratings_seq, ts_seq, user_stats
        """
        movies_seq  = torch.tensor(u['movies_seq'], dtype=torch.long)
        ratings_seq = torch.tensor(u['ratings_seq'], dtype=torch.float32)
        ts_seq      = torch.tensor(u['ts_seq'], dtype=torch.float32)
       
        stats_cols  = [c for c in u.index if c.startswith(('num_rating','avg_rating','weekend_watcher','genre_','type_of_viewer_'))]
        user_stats  = torch.tensor(u[stats_cols]
                                        .astype('float32').values,dtype=torch.float32)

        return movies_seq, ratings_seq, ts_seq, user_stats

In [519]:
def collect_movie_features(m, max_len_a, max_len_d, max_len_g):
        """
        Zwraca cztery tensory: combined, actor_ids, director_ids, genre_ids
        """
        numeric = [
            m.runtime,
            m.engagement_score,
            m.cast_importance,
            m.director_score,
        ]
        binary = [
            m.if_blockbuster,
            m.highly_watched,
            m.highly_rated,
            m.has_keywords,
            m.has_cast,
            m.has_director,
        ]
        decades = (m[[c for c in m.index if c.startswith('decade_')]]
                   .astype(int)
                   .tolist())

        dense_feats = torch.tensor(numeric + binary + decades, dtype=torch.float32)
        text_emb = torch.tensor(m.text_embedded, dtype=torch.float32)

        def pad(seq, L):
            seq_list = list(seq) if not isinstance(seq, list) else seq
            padded = seq_list[:L] + [0] * max(0, L - len(seq_list))
            return torch.tensor(padded, dtype=torch.long)

        actor_ids    = pad(m.actor_ids,    max_len_a)
        director_ids = pad(m.director_ids, max_len_d)
        genre_ids    = pad(m.genre_ids,    max_len_g)

        return dense_feats, text_emb, actor_ids, director_ids, genre_ids

In [520]:
import faiss
'''
Do zbudowania macierzy embeedingow dla FAISS, do szyukania najblizszych sasiadow
'''

# unique_ids = df_movies.index.tolist()
movie_vecs = []

for m_id in df_movies.index:
    dense_feats, text_emb, *_ = collect_movie_features(
        df_movies.loc[m_id],
        max_len_a, max_len_d, max_len_g
    )
    combined = torch.cat([dense_feats, text_emb], dim=0)
    # normalizujemy L2 na potrzeby FAISS cosinusowego (wyplaszczanie)
    movie_vecs.append(F.normalize(combined, dim=0))

movie_matrix = torch.stack(movie_vecs)  # macierz [n_movies, D]
movie_matrix_np = movie_matrix.cpu().numpy().astype('float32')
# FAISS IP po L2-normalizacji = cosine similarity
faiss_index = faiss.IndexFlatIP(movie_matrix_np.shape[1])
faiss_index.add(movie_matrix_np)

In [521]:
# DO OCENY I EWENTUALNYCH ZMIAN
def find_negative(pos_id, user_negs, k, top_k=200):
    """
    Dla danego pozytywu (pos_id) szuka w FAISS najbliższego negatywu z listy user_negs. Jeśli żaden z top_k nie należy do user_negs to fallback = losowy wybór z user_negs.
    """
    local_pos = movie_to_local[pos_id]

    # Zakladamy co najmniej jeden pos_id
    _, I = faiss_index.search(movie_matrix_np[local_pos].reshape(1, -1), top_k)

    negs = []
    for candidate in I[0]:
        global_candidate = local_to_movie[candidate]
        if global_candidate in user_negs and global_candidate not in negs:
            negs.append(global_candidate)
            if len(negs) == k:
                return negs

    top_candidates = [local_to_movie[i] for i in I[0] if local_to_movie[i] not in negs]
    to_add = random.sample(
        top_candidates if top_candidates else list(user_negs),
        k - len(negs)
    )
    negs.extend(to_add)
    return negs

In [522]:
def find_negative_mixed(pos_id, user_negs, k, top_k=200, hard_frac=0.5):
    """
    Zwraca K negatywów:
     - k_h = int(k * hard_frac) twardych z FAISS
     - k_r = k-k_h losowych z całego zbioru użytkownika
    """
    if len(user_negs) <= k:
        return random.choices(list(user_negs), k=k)

    k_h = int(k * hard_frac)
    k_r = k - k_h

    local_pos = movie_to_local[pos_id]
    _, I = faiss_index.search(movie_matrix_np[local_pos].reshape(1, -1), top_k)
    hard_cands = [
        local_to_movie[i] for i in I[0]
        if local_to_movie[i] in user_negs
    ]
    hard = random.sample(hard_cands, min(k_h, len(hard_cands)))

    remaining = list(user_negs - set(hard))
    rand = random.sample(remaining, min(k_r, len(remaining)))

    all_negs = hard + rand

    if len(all_negs) < k:
        more = list(user_negs - set(all_negs))
        add_n = min(k - len(all_negs), len(more))
        all_negs.extend(random.sample(more, add_n))

    return all_negs

In [523]:
import collections

user_id = df_ratings.index[0]
pos_id = df_ratings.at[user_id, 'pos'][0]
neg_set = set(df_ratings.at[user_id, 'neg'])

k = 5
top_k = 100
trials = 1000

samples = [find_negative_mixed(pos_id, neg_set, k=k, top_k=top_k) for _ in range(trials)]

flat = [nid for sublist in samples for nid in sublist]
counter = collections.Counter(flat)

df_counts = pd.DataFrame(counter.items(), columns=['movie_id', 'count']).sort_values('count', ascending=False)

print("NEG: ", df_counts)

NEG:      movie_id  count
1       1694    431
9       2316    431
0       2161    429
10      1820    428
5       2602    419
36      1314     76
34      1964     67
2       1263     66
22      1713     65
48        28     64
18      1360     63
49      1961     62
14      2394     62
42      1160     61
11      2005     61
26       108     60
38      1841     60
46      2580     59
3       2299     59
23        24     57
30      2216     56
27        35     55
39      2126     55
52       599     55
47      2244     54
45      1862     54
15      1255     54
7       2137     54
33       351     54
24       220     53
50       812     53
53      1167     53
56       585     53
35       945     53
29      2966     52
32       886     52
20      2204     52
6       1173     52
28      1137     52
21      1857     52
17      1172     52
40      1202     51
12      2771     51
8       1086     50
51      1102     50
57        33     50
31      1925     48
44       522     48
54      1236  

In [524]:
class UserOnlyDataset(Dataset):
    def __init__(self, df_users):
        self.df_users = df_users.reset_index(drop=True)

    def __len__(self):
        return len(self.df_users)

    def __getitem__(self, idx):
        u_row = self.df_users.iloc[idx]
        movies_seq, ratings_seq, ts_seq, user_stats = collect_user_features(u_row)
        return {
            'user': {
                'user_statistics': user_stats,
                'movies': movies_seq,
                'ratings': ratings_seq,
                'times': ts_seq,
            }
        }

In [525]:
class MovieDataset(Dataset):
    '''
    Potrzebny do stworzenia matrix-a pod LOOCV
    '''
    def __init__(self, df_movies):
        self.df = df_movies
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        m = self.df.iloc[idx]
        return collect_movie_features(m, max_len_a, max_len_d, max_len_g)

In [526]:
class TwoTowerDataset(Dataset):

    def __init__(self, df_users, df_ratings, df_movies):
        self.df_users = df_users.reset_index(drop=True)
        self.df_ratings = df_ratings
        self.df_movies = df_movies

    def __len__(self):
        return len(self.df_users)

    def __getitem__(self, idx):
        # User features
        u_row = self.df_users.iloc[idx]
        movies_seq, ratings_seq, ts_seq, user_stats = collect_user_features(u_row)
        user_id = u_row['userId']

        pos_list = self.df_ratings.at[user_id, 'pos']
        neg_list = self.df_ratings.at[user_id, 'neg']

        #BPR
        pos_id = random.choice(pos_list)
        # neg_id = find_negative(pos_id,set(neg_list))
        k = 2 # Liczba negatywow
        # neg_ids = find_negative_mixed(pos_id, set(neg_list), k)

        # Szybki test na randomowych
        if len(neg_list) >= 2:
            neg_ids = random.sample(neg_list, k=2)
        else:
            neg_ids = random.choices(neg_list, k=2)

        assert pos_id not in neg_ids, "Wylosowałeś negatyw równy pozytywowi!"
        assert len(neg_ids) == k,      "Zła liczba negatywów"
        # assert len(set(neg_ids)) == k, "Duplikaty wśród negatywów"

        m_pos = self.df_movies.loc[pos_id]
        pos_feats, pos_text, pos_actors, pos_directors, pos_genres = collect_movie_features(m_pos, max_len_a, max_len_d, max_len_g)

        # m_neg = self.df_movies.loc[neg_id]
        # neg_feats, neg_text, neg_actors, neg_directors, neg_genres = collect_movie_features(m_neg, max_len_a, max_len_d, max_len_g)

        neg_feats_list, neg_text_list, neg_actor_list, neg_director_list, neg_genre_list = [], [], [], [], []
        for nid in neg_ids:
            m_neg = self.df_movies.loc[nid]
            nf, nt, na, nd, ng = collect_movie_features(m_neg, max_len_a, max_len_d, max_len_g)
            neg_feats_list.append(nf)
            neg_text_list.append(nt)
            neg_actor_list.append(na)
            neg_director_list.append(nd)
            neg_genre_list.append(ng)

        return {
            'user': {
                'user_statistics': user_stats,
                'movies': movies_seq,
                'ratings': ratings_seq,
                'times': ts_seq,
            },
            'pos_item': {
                'dense_features': pos_feats,
                'text_embedding': pos_text,
                'actor_ids': pos_actors,
                'director_ids': pos_directors,
                'genre_ids': pos_genres,
            },
            # 'neg_item': {
            #     'dense_features': neg_feats,
            #     'text_embedding': neg_text,
            #     'actor_ids': neg_actors,
            #     'director_ids': neg_directors,
            #     'genre_ids': neg_genres,
            # }
            'neg_item': {
                'dense_features':  torch.stack(neg_feats_list),    # [k, dense_feat_dim]
                'text_embedding':  torch.stack(neg_text_list),     # [k, text_emb_dim]
                'actor_ids':       torch.stack(neg_actor_list),    # [k, max_len_a]
                'director_ids':    torch.stack(neg_director_list), # [k, max_len_d]
                'genre_ids':       torch.stack(neg_genre_list),    # [k, max_len_g]
            }
        }

In [527]:
'''
TEST DATASETU I ODPOWIEDNIEGO OUTPUTU POJEDYNCZEGO OBIEKTU GET_ITEM
'''
dataset_test = TwoTowerDataset(df_users, df_ratings, df_movies)

sample0 = dataset_test[0]

print("Keys:", sample0.keys())
print("\n--- USER ---")
for k,v in sample0['user'].items():
    print(f" user[{k}]:", type(v), getattr(v, "shape", v[:5] if isinstance(v,list) else v))

print("\n--- POS ITEM ---")
for k,v in sample0['pos_item'].items():
    print(f" pos_item[{k}]:", type(v), v.shape if hasattr(v,'shape') else v[:5])

print("\n--- NEG ITEM ---")
for k,v in sample0['neg_item'].items():
    print(f" neg_item[{k}]:", type(v), v.shape if hasattr(v,'shape') else v[:5])

Keys: dict_keys(['user', 'pos_item', 'neg_item'])

--- USER ---
 user[user_statistics]: <class 'torch.Tensor'> torch.Size([25])
 user[movies]: <class 'torch.Tensor'> torch.Size([20])
 user[ratings]: <class 'torch.Tensor'> torch.Size([20])
 user[times]: <class 'torch.Tensor'> torch.Size([20])

--- POS ITEM ---
 pos_item[dense_features]: <class 'torch.Tensor'> torch.Size([24])
 pos_item[text_embedding]: <class 'torch.Tensor'> torch.Size([300])
 pos_item[actor_ids]: <class 'torch.Tensor'> torch.Size([5])
 pos_item[director_ids]: <class 'torch.Tensor'> torch.Size([3])
 pos_item[genre_ids]: <class 'torch.Tensor'> torch.Size([9])

--- NEG ITEM ---
 neg_item[dense_features]: <class 'torch.Tensor'> torch.Size([2, 24])
 neg_item[text_embedding]: <class 'torch.Tensor'> torch.Size([2, 300])
 neg_item[actor_ids]: <class 'torch.Tensor'> torch.Size([2, 5])
 neg_item[director_ids]: <class 'torch.Tensor'> torch.Size([2, 3])
 neg_item[genre_ids]: <class 'torch.Tensor'> torch.Size([2, 9])


In [528]:
def collate_TT(batch):
    '''
    Pelny batchowanie danych do uczenia
    '''
    user_movies, user_ratings, user_times, user_stats = [], [], [], []
    pos_dense, pos_text, pos_actor, pos_director, pos_genre = [], [], [], [], []
    neg_dense, neg_text, neg_actor, neg_director, neg_genre = [], [], [], [], []

    for row in batch:

        user_stats.append(row['user']['user_statistics'])
        user_movies.append(row['user']['movies'])
        user_ratings.append(row['user']['ratings'])
        user_times.append(row['user']['times'])

        pos_dense.append(row['pos_item']['dense_features'])
        pos_text.append(row['pos_item']['text_embedding'])
        pos_actor.append(row['pos_item']['actor_ids'])
        pos_director.append(row['pos_item']['director_ids'])
        pos_genre.append(row['pos_item']['genre_ids'])

        neg_dense.append(row['neg_item']['dense_features']) # [k, D_feat]
        neg_text.append(row['neg_item']['text_embedding'])  # [k, D_text]
        neg_actor.append(row['neg_item']['actor_ids'])
        neg_director.append(row['neg_item']['director_ids'])
        neg_genre.append(row['neg_item']['genre_ids'])

    batch_user = {
        'user_statistics': torch.stack(user_stats),     # [B, d_stats]
        'movies': torch.stack(user_movies),             # [B, L_u]
        'ratings': torch.stack(user_ratings),           # [B, L_u]
        'times': torch.stack(user_times),               # [B, L_u]
    }

    batch_pos_item = {
        'dense_features': torch.stack(pos_dense),       # [B, dense_feat_dim]
        'text_embedding': torch.stack(pos_text),        # [B, text_emb_dim]
        'actor_ids': torch.stack(pos_actor),            # [B, max_len_a]
        'director_ids':torch.stack(pos_director),       # [B, max_len_d]
        'genre_ids': torch.stack(pos_genre),            # [B, max_len_g]
    }

    batch_neg_item = {
        'dense_features': torch.stack(neg_dense),
        'text_embedding': torch.stack(neg_text),
        'actor_ids': torch.stack(neg_actor),
        'director_ids': torch.stack(neg_director),
        'genre_ids': torch.stack(neg_genre),
    }

    return {
      'user': batch_user,
      'pos_item': batch_pos_item,
      'neg_item': batch_neg_item
    }

In [529]:
def collateUser(batch):
    '''
    Przygotowujemy batch zawierajace dane tylko user-a, potrzebne do leave-one-out
    '''
    movies, ratings, times, stats = [], [], [], []

    for row in batch:

        movies.append(row['user']['movies'])
        ratings.append(row['user']['ratings'])
        times.append(row['user']['times'])
        stats.append(row['user']['user_statistics'])

    return {
        'user': {
            'user_statistics': torch.stack(stats),  # [B, d_stats]
            'movies': torch.stack(movies),          # [B, L_u]
            'ratings': torch.stack(ratings),        # [B, L_u]
            'times': torch.stack(times)             # [B, L_u]
        }
    }

# Przygotowanie zbiorów do treningu

In [530]:
BATCH_SIZE = 8192 # FOR TEST: 4
train_users, val_users = train_test_split(
    df_users,
    test_size=0.2,
    random_state=213
)

train_ratings = df_ratings[df_ratings.index.isin(train_users['userId'])].copy()
val_ratings = df_ratings[df_ratings.index.isin(val_users['userId'])].copy()

assert set(train_users['userId']) <= set(train_ratings.index)
assert set(val_users  ['userId']) <= set(val_ratings.index)

In [531]:
'''
Tworzymy do pozniejszej walidacji leave-one-out w heavy_evaluate
'''

val_ratings_heavy = val_ratings.copy()
val_loocv = []

for user_id, row in val_ratings_heavy.iterrows():
    if len(row['pos']) < 2:
        continue                    # pomijamy jezeli > 2
    pos_list = row['pos']           # wwszytkie pos (wieksze od > 2)
    hold = pos_list[-1]             # Bierzemy ostatni do hold-out
    train = pos_list[:-1]

    val_loocv.append({'userId': user_id, 'pos': [hold]})
    val_ratings_heavy.at[user_id, 'pos'] = train

val_loocv = pd.DataFrame(val_loocv).set_index('userId')

In [532]:
'''
Przygotowujemy dane potrzebne do leave-one-out
'''
# train_pos_sets_val = {
#     user_id: set(pos_list)
#     for user_id, pos_list in val_ratings['pos'].items()
# }
#
# test_pos_val = val_loocv['pos'].to_dict()

train_pos_sets_val_global = {                           # globalne zbiory pozytywów i hold-outów z val_ratings_heavy
    u: set(pos_list)
    for u, pos_list in val_ratings_heavy['pos'].items()
}
test_pos_val_global = val_loocv['pos'].to_dict()

train_pos_sets_val = {                                  # mapowanie na lokalne indeksy w movie_matrix_np
    u: { movie_to_local[mid] for mid in global_set }
    for u, global_set in train_pos_sets_val_global.items()
}
test_pos_val = {
    u: [ movie_to_local[mid] for mid in global_list ]
    for u, global_list in test_pos_val_global.items()
}

heavy_user_ids = val_loocv.index.tolist() # listę tylko tych userów z val, którzy mają >=2 pozytywów
heavy_users = df_users[df_users['userId'].isin(heavy_user_ids)].reset_index(drop=True)

all_user_ids = heavy_user_ids


In [533]:
train_dataset = TwoTowerDataset(
    train_users,
    train_ratings,
    df_movies
)
val_dataset_light = TwoTowerDataset(
    val_users,
    val_ratings,         # pełna lista pos i neg
    df_movies
)
# val_dataset_heavy = TwoTowerDataset(
#     val_users,
#     val_loocv,          # pojedyncze hold-out’y
#     df_movies
# )
val_dataset_heavy = UserOnlyDataset(heavy_users)

In [534]:
train_loader = DataLoader(
    dataset       = train_dataset,
    batch_size    = BATCH_SIZE,
    shuffle       = True,
    # num_workers   = 2,
    pin_memory    = True,
    collate_fn    = collate_TT,
    drop_last     = False
)
val_light_loader = DataLoader(
    dataset       = val_dataset_light,
    batch_size    = BATCH_SIZE,
    shuffle       = False,
    # num_workers   = 2,
    pin_memory    = True,
    collate_fn    = collate_TT,
    drop_last     = False
)
# val_heavy_loader = DataLoader(
#     dataset       = val_dataset_heavy,
#     batch_size    = BATCH_SIZE,
#     shuffle       = False,
#     # num_workers   = 2,
#     pin_memory    = True,
#     collate_fn    = collateUser,
#     drop_last     = False
# )
val_heavy_loader = DataLoader(
    dataset    = val_dataset_heavy,
    batch_size = BATCH_SIZE,
    shuffle    = False,
    pin_memory = True,
    collate_fn = collateUser,
)
movie_loader = DataLoader(
    MovieDataset(df_movies),
    batch_size=8192,
    collate_fn=lambda batch: {
        'pos_item': {
            'dense_features': torch.stack([b[0] for b in batch]),
            'text_embedding': torch.stack([b[1] for b in batch]),
            'actor_ids':      torch.stack([b[2] for b in batch]),
            'director_ids':   torch.stack([b[3] for b in batch]),
            'genre_ids':      torch.stack([b[4] for b in batch]),
        }
    }
)

In [535]:
'''
TEST CUSTOMOWEJ FUNKCJI collateTT I DATALOADER-OW
'''
device = torch.device("cuda")
dataset_test = TwoTowerDataset(df_users, df_ratings, df_movies)

loader_test_full = DataLoader(
    dataset_test,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_TT,
)

batch_test = next(iter(loader_test_full))

print("=== USER ===")
for k,v in batch_test['user'].items():
    print(f"{k:10s}:", v.shape)

print("\n=== POS ITEM ===")
for k,v in batch_test['pos_item'].items():
    print(f"{k:15s}:", v.shape)

print("\n=== NEG ITEM ===")
for k,v in batch_test['neg_item'].items():
    print(f"{k:15s}:", v.shape)

=== USER ===
user_statistics: torch.Size([4, 25])
movies    : torch.Size([4, 20])
ratings   : torch.Size([4, 20])
times     : torch.Size([4, 20])

=== POS ITEM ===
dense_features : torch.Size([4, 24])
text_embedding : torch.Size([4, 300])
actor_ids      : torch.Size([4, 5])
director_ids   : torch.Size([4, 3])
genre_ids      : torch.Size([4, 9])

=== NEG ITEM ===
dense_features : torch.Size([4, 2, 24])
text_embedding : torch.Size([4, 2, 300])
actor_ids      : torch.Size([4, 2, 5])
director_ids   : torch.Size([4, 2, 3])
genre_ids      : torch.Size([4, 2, 9])


In [536]:
loader_test_user = DataLoader(
    dataset_test,
    batch_size=4,
    shuffle=False,
    collate_fn=collateUser,
)

batch_user = next(iter(loader_test_user))

print("\n=== USER-ONLY BATCH (collateUser) ===")
for k, v in batch_user['user'].items():
    print(f"{k:12s} ->", v.shape)


=== USER-ONLY BATCH (collateUser) ===
user_statistics -> torch.Size([4, 25])
movies       -> torch.Size([4, 20])
ratings      -> torch.Size([4, 20])
times        -> torch.Size([4, 20])


In [537]:
batch_test_3 = next(iter(train_loader))

print("=== USER ===")
for k,v in batch_test_3['user'].items():
    print(f"{k:10s}:", v.shape)

print("\n=== POS ITEM ===")
for k,v in batch_test_3['pos_item'].items():
    print(f"{k:15s}:", v.shape)

print("\n=== NEG ITEM ===")
for k,v in batch_test_3['neg_item'].items():
    print(f"{k:15s}:", v.shape)

=== USER ===
user_statistics: torch.Size([8192, 25])
movies    : torch.Size([8192, 20])
ratings   : torch.Size([8192, 20])
times     : torch.Size([8192, 20])

=== POS ITEM ===
dense_features : torch.Size([8192, 24])
text_embedding : torch.Size([8192, 300])
actor_ids      : torch.Size([8192, 5])
director_ids   : torch.Size([8192, 3])
genre_ids      : torch.Size([8192, 9])

=== NEG ITEM ===
dense_features : torch.Size([8192, 2, 24])
text_embedding : torch.Size([8192, 2, 300])
actor_ids      : torch.Size([8192, 2, 5])
director_ids   : torch.Size([8192, 2, 3])
genre_ids      : torch.Size([8192, 2, 9])


# ARCHITEKTURA TWO TOWER

In [538]:
EMB_DIM = 64

class UserTower(nn.Module):
    def __init__(self, input_dim, n_items, embedding_dim=EMB_DIM):
        '''
        input_dim - the number of columns in user features, without sequence columns
        '''
        super().__init__()

        self.item_emb = nn.Embedding(n_items, embedding_dim)

        # A layer to project rating and timestamp into a scalar weight
        self.rating_proj = nn.Linear(2, 1)

        self.mlp = nn.Sequential(
            nn.Linear(input_dim + embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 384),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(384, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, embedding_dim)
        )

    def forward(self, batch):
        # Embed movieIds liked by user
        m = self.item_emb(batch['movies'])

        # Get weights
        x = torch.stack([batch['ratings'], batch['times']], dim=-1) # [B, L_u, 2]
        w = torch.sigmoid(self.rating_proj(x))

        # weighted mean-pool
        pooled = (m * w).sum(1) / w.sum(1).clamp_min(1e-6)   # [B, D]

        input = torch.cat([batch['user_statistics'], pooled], dim=-1) # [B, stats+EMB_DIM]
        output = self.mlp(input)                                    # [B, EMB_DIM]
        u = F.normalize(output, dim = 1)
        return u


class ItemTower(nn.Module):
    def __init__(self,dense_feat_dim,text_emb_dim,vocab_sizes,embedding_dim=EMB_DIM):
        '''
        vocab_sizes - tuple odpowiednio n_actors, n_directors, n_genres
        dense_feat_dim – wymiary numeric+binary+decades+text
        tex_emb_dim - Wektor o wielkosc 300 opisujacy dane tekstowe filmu
        '''
        super().__init__()

        self.actor_emb = nn.Embedding(vocab_sizes[0], embedding_dim)
        self.director_emb = nn.Embedding(vocab_sizes[1], embedding_dim)
        self.genre_emb = nn.Embedding(vocab_sizes[2], embedding_dim)

        self.meta_mlp = nn.Sequential(
            nn.Linear(dense_feat_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, embedding_dim),
            nn.ReLU()
        )

        self.text_mlp = nn.Sequential( #--- to consider za ostre zejscie z 512 -> 64, moze posredni 256
            nn.Linear(text_emb_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, embedding_dim),
            nn.ReLU()
        )

        MLP_INPUT_DIM = embedding_dim*5 # odpowiednio nn.Embeedings * 3 oraz meta_mlp oraz text_mlp
        self.final_mlp = nn.Sequential(
            nn.Linear(MLP_INPUT_DIM, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512,256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256,embedding_dim)
        )

    def forward(self, batch, key: str = "pos_item"):

        dense_feats = batch[key]['dense_features']     # [B, dense_feat_dim]
        text_emb = batch[key]['text_embedding']     # [B, text_emb_dim]

        actor_ids = batch[key]['actor_ids']         # [B, max_len_a]
        director_ids = batch[key]['director_ids']
        genre_ids = batch[key]['genre_ids']

        if dense_feats.dim() == 3:
            B, k, Z = dense_feats.shape

            # flattenujemy
            dense_flat     = dense_feats.view(B*k, Z)
            text_flat      = text_emb.view(B*k, -1)
            actor_flat     = actor_ids.view(B*k, -1)
            director_flat  = director_ids.view(B*k, -1)
            genre_flat     = genre_ids.view(B*k, -1)

            # złożony batch
            flat_batch = {
                key: {
                    'dense_features':  dense_flat,
                    'text_embedding':  text_flat,
                    'actor_ids':       actor_flat,
                    'director_ids':    director_flat,
                    'genre_ids':       genre_flat,
                }
            }

            emb_flat = self.forward(flat_batch, key)    # rekurencyjnie batch na embeddingi [B*k, D]

            return emb_flat.view(B, k, -1)              # [B, k, D]

        dense_vec = self.meta_mlp(dense_feats)      # [B, D]
        text_vec = self.text_mlp(text_emb)          # [B, D]

        cast_imp = dense_feats[:, 2:3]              # [B, 1]
        director_score = dense_feats[:, 3:4]        # [B, 1]

        a = self.actor_emb   (actor_ids).mean(dim=1)    # [B, D]
        d = self.director_emb(director_ids).mean(dim=1) # [B, D]
        g = self.genre_emb   (genre_ids).mean(dim=1)    # [B, D]

        # We add weights based on importance score
        a = a * cast_imp
        d = d * director_score #--- do rozwazenia Max pooling lub Attention pooling

        input = torch.cat([a, d, g, dense_vec, text_vec], dim=-1)   # [B, 5D]
        output = self.final_mlp(input)                              # [B, D]
        i = F.normalize(output, dim=1)
        return i


In [539]:
class TwoTowerModel(nn.Module):
    def __init__(self, stats_dim, n_items, vocab_sizes,
                 dense_feat_dim, text_emb_dim, embedding_dim=EMB_DIM):
        super().__init__()
        self.user_tower = UserTower(stats_dim, n_items, embedding_dim)
        self.item_tower = ItemTower(dense_feat_dim, text_emb_dim, vocab_sizes, embedding_dim)

    def forward(self, batch):
        u = self.user_tower(batch['user'])
        i_pos = self.item_tower(batch, key="pos_item")
        i_neg = self.item_tower(batch, key="neg_item")

        if i_neg.dim() == 2:
            return u, i_pos, i_neg # każdy [B, 64]

        B, k, D = i_neg.shape

        i_neg_flat = i_neg.reshape(B*k, D) # Splaszczamy

        u_flat = u.unsqueeze(1).expand(B, k, D).reshape(B*k, D)
        pos_flat = i_pos.unsqueeze(1).expand(B, k, D).reshape(B*k, D)

        return u_flat, pos_flat, i_neg_flat


In [540]:
'''
TEST ARCHITEKTURY MODELOW
'''
device = torch.device("cuda")
model_test  = TwoTowerModel(stats_dim=25,
                       n_items=n_items,
                       vocab_sizes=(num_actors, num_directors, num_genres),
                       dense_feat_dim=24,
                       text_emb_dim=300,
                       embedding_dim=64).to(device)

# First batch
batch_test_2 = next(iter(loader_test_full))

batch_test_2 = {
  'user':      {k: v.to(device, non_blocking=True) for k,v in batch_test_2['user'].items()},
  'pos_item':  {k: v.to(device, non_blocking=True) for k,v in batch_test_2['pos_item'].items()},
  'neg_item':  {k: v.to(device, non_blocking=True) for k,v in batch_test_2['neg_item'].items()},
}

# Forward pass
u_test, i_pos_test, i_neg_test = model_test(batch_test_2)

print("u.shape:",     u_test.shape)      # -> [B, 64]
print("i_pos.shape:", i_pos_test.shape)  # -> [B, 64]
print("i_neg.shape:", i_neg_test.shape)  # -> [B, 64] Dla pojedynczego /  [B, k, 64] Dla wiecej negatywow


u.shape: torch.Size([8, 64])
i_pos.shape: torch.Size([8, 64])
i_neg.shape: torch.Size([8, 64])


# TRENOWANIE

In [541]:
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.mps.is_available():
    device = torch.device('mps')
print('Device:', device)

Device: cuda


In [542]:
def to_device(data, device):
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    elif torch.is_tensor(data):
        return data.to(device, non_blocking=True)
    else:
        return data

In [543]:
'''
Przygotowanie matrix-u do leave-one-out w celu 'score' do rankingu
'''
def compute_item_embeddings(model, movie_loader):
    model.eval()
    all_embs = []
    with torch.no_grad():
        for mb in movie_loader:
            mb = to_device(mb, device)

            embs = model.item_tower(mb, key='pos_item')  # [batch_size, D]
            all_embs.append(embs)
    return torch.cat(all_embs, dim=0).cpu().numpy()  # [n_movies, D]

In [544]:
'''
Definicja loss-u BPR (Bayesian Personalized Ranking)
'''
def bpr_loss(u, i_pos, i_neg):
    pos = (u*i_pos).sum(1) # [B] score pozytywnych par
    neg = (u*i_neg).sum(1)
    return -torch.log(torch.sigmoid(pos-neg) + 1e-8).mean()

In [545]:
'''
Trenowanie jednej epoki, dodano odpowiednie inputy tez do testow i ewentualnych zmian

Obecnie:
- model: TwoTowerModel
- loader: DataLoader
- optimizer: Adam
- loss: bpr_loss
'''
def train_one_epoch(model, loader, optimizer):
    model.train()
    running_loss = 0.0

    for raw in tqdm(train_loader, desc=f" Epoch {epoch} batches", leave=False):
        batch = to_device(raw, device)
        optimizer.zero_grad()

        user_vec, pos_vec, neg_vec = model(batch) # forward do TwoTowerModel

        loss = bpr_loss(user_vec, pos_vec, neg_vec)

        loss.backward() # Backword i updatujemy parametry
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss/len(loader) # Do wyliczania sredniej straty w epoce
    return epoch_loss

In [546]:
'''
Lekka ewaluacja majaca za zadanie pokazac czy model sie uczy, niz odpowiadac jak dobrze tworzy ranking
'''
def light_evaluate(model, loader):
    model.eval()
    aucs, paac = [], []

    with torch.no_grad():
        for raw in loader:
            batch = to_device(raw, device)

            user_vec, pos_vec, neg_vec = model(batch)

            pos_score = (user_vec * pos_vec).sum(dim = -1) # [B]
            neg_score = (user_vec * neg_vec).sum(dim = -1)

            # ROC AUC
            labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
            scores = torch.cat([pos_score, neg_score])
            aucs.append(roc_auc_score(labels.cpu(), scores.cpu()))

            # Pair-wise accuarcy
            paac.append((pos_score > neg_score).float().mean().item())

    return float(np.mean(aucs)), float(np.mean(paac))

In [547]:
'''
Dokladniejsza ewaluacja majaca odpowiedziec jak model radzi sobie z rankingiem dla danych uzytkownikow
'''
def heavy_evaluate(model,user_loader,item_embs_np,
                        train_pos_sets,test_pos,top_N):
    model.eval()
    user_embs = []

    with torch.no_grad():
        for raw in user_loader:
            batch = to_device(raw, device)

            u = model.user_tower(batch['user'])  # Skupiamy sie tylko na zebraniu embeddingow uzytkownika

            user_embs.append(u.cpu().numpy())

    user_embs = np.vstack(user_embs)    # [U-liczba uzytkownikow, D]

    assert len(all_user_ids) == user_embs.shape[0]
    recalls, mrrs = [], []

    for idx, user_id in enumerate(all_user_ids):
        vec = user_embs[idx]                # [D] wektor emb usera
        scores = item_embs_np @ vec         # [I] wektory score, do oceny czy to dziala poprawnie ? 'iloczyny skalarne'

        mask = np.zeros_like(scores, dtype=bool)
        mask[list(train_pos_sets[user_id])] = True  # Tworzymy maske do odsiania filmow ktore user juz widzial
        scores[mask] = -1e9

        ranked = np.argsort(-scores)[:top_N]    # Ranking
        true_set = test_pos[user_id]            # hold-out

        # Recall@K
        recalls.append(int(any(r in true_set for r in ranked)))

        # MRR@K
        rr = 0.0
        for rank, idx in enumerate(ranked, 1):
            if idx in true_set:
                rr = 1.0/rank
                break
        mrrs.append(rr)

    return float(np.mean(recalls)), float(np.mean(mrrs))

In [548]:
'''
Early stopping
'''
best_val = 0.0                    # dla metryk, które chcemy maksymalizować (np. ROC-AUC)
epochs_no_improve = 0
patience = 5                      # maksymalna liczba epok bez poprawy
save_path = "best_model.pt"       # gdzie będziemy dumpować najlepszy model

In [549]:
EPOCHS = 50
TOP_N = 20

model = (TwoTowerModel(stats_dim=25,
                       n_items=n_items,
                       vocab_sizes=(num_actors, num_directors, num_genres),
                       dense_feat_dim=24,
                       text_emb_dim=300,
                       embedding_dim=EMB_DIM)
         .to(device))
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS) # zmieniamy LR zgodnie z kosinusem (powinno stabilizowac trening)

for epoch in trange(1, EPOCHS+1, desc="Epochs"):

    tr_loss = train_one_epoch(model, train_loader, optimizer) # Logika treningu

    scheduler.step() # optymalizacja LR

    print(f"Epoch {epoch:2d} | train_loss={tr_loss:.4f}")
    if epoch % 2 == 0:

        movie_matrix_np = compute_item_embeddings(model, movie_loader) # [n_movies, D] wyliczamy embeedingi filmow
        D = movie_matrix_np.shape[1]

        faiss_index = faiss.IndexFlatIP(D) # Nowy indeks pod FAISS
        faiss_index.add(movie_matrix_np)

        auc, pair_acc = light_evaluate(model, val_light_loader)
        print(f"LIGHT eval | val ROC-AUC={auc:.4f} | pair-acc={pair_acc:.4f}")

        if auc > best_val + 1e-4:
            best_val = auc
            epochs_no_improve = 0
            torch.save(model.state_dict(), save_path)
            print(f"  poprawa! zapisano model (ROC-AUC={best_val:.4f})")
        else:
            epochs_no_improve += 1
            print(f"  brak poprawy ({epochs_no_improve}/{patience})")

    if epoch % 5 == 0:
        recall, mrr = heavy_evaluate(
            model,
            val_heavy_loader,               # loader zwracający tylko user embeddings
            movie_matrix_np,                # matrix do score-a
            train_pos_sets_val,
            test_pos_val,
            top_N=TOP_N
        )
        print(f"HEAVY eval | @K={TOP_N}: Recall@{TOP_N}={recall:.4f}, MRR@{TOP_N}={mrr:.4f}")

    # ---- sprawdzamy early stopping ----
    if epochs_no_improve >= patience:
        print(f"\nEarly stopping — przez {patience} epok nie było lepszego ROC-AUC.")
        break

model.load_state_dict(torch.load(save_path))
model.eval()

Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

 Epoch 1 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch  1 | train_loss=0.6452


 Epoch 2 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch  2 | train_loss=0.5957
LIGHT eval | val ROC-AUC=0.6949 | pair-acc=0.7067
  poprawa! zapisano model (ROC-AUC=0.6949)


 Epoch 3 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch  3 | train_loss=0.5795


 Epoch 4 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch  4 | train_loss=0.5737
LIGHT eval | val ROC-AUC=0.7048 | pair-acc=0.7188
  poprawa! zapisano model (ROC-AUC=0.7048)


 Epoch 5 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch  5 | train_loss=0.5697
HEAVY eval | @K=20: Recall@20=0.0223, MRR@20=0.0052


 Epoch 6 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch  6 | train_loss=0.5665
LIGHT eval | val ROC-AUC=0.7089 | pair-acc=0.7201
  poprawa! zapisano model (ROC-AUC=0.7089)


 Epoch 7 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch  7 | train_loss=0.5643


 Epoch 8 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch  8 | train_loss=0.5625
LIGHT eval | val ROC-AUC=0.7146 | pair-acc=0.7288
  poprawa! zapisano model (ROC-AUC=0.7146)


 Epoch 9 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch  9 | train_loss=0.5601


 Epoch 10 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 10 | train_loss=0.5605
LIGHT eval | val ROC-AUC=0.7164 | pair-acc=0.7296
  poprawa! zapisano model (ROC-AUC=0.7164)
HEAVY eval | @K=20: Recall@20=0.0208, MRR@20=0.0051


 Epoch 11 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 11 | train_loss=0.5585


 Epoch 12 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 12 | train_loss=0.5581
LIGHT eval | val ROC-AUC=0.7204 | pair-acc=0.7338
  poprawa! zapisano model (ROC-AUC=0.7204)


 Epoch 13 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 13 | train_loss=0.5571


 Epoch 14 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 14 | train_loss=0.5562
LIGHT eval | val ROC-AUC=0.7212 | pair-acc=0.7330
  poprawa! zapisano model (ROC-AUC=0.7212)


 Epoch 15 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 15 | train_loss=0.5565
HEAVY eval | @K=20: Recall@20=0.0160, MRR@20=0.0037


 Epoch 16 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 16 | train_loss=0.5550
LIGHT eval | val ROC-AUC=0.7206 | pair-acc=0.7352
  brak poprawy (1/5)


 Epoch 17 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 17 | train_loss=0.5546


 Epoch 18 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 18 | train_loss=0.5541
LIGHT eval | val ROC-AUC=0.7198 | pair-acc=0.7339
  brak poprawy (2/5)


 Epoch 19 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 19 | train_loss=0.5541


 Epoch 20 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 20 | train_loss=0.5529
LIGHT eval | val ROC-AUC=0.7245 | pair-acc=0.7386
  poprawa! zapisano model (ROC-AUC=0.7245)
HEAVY eval | @K=20: Recall@20=0.0162, MRR@20=0.0038


 Epoch 21 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 21 | train_loss=0.5518


 Epoch 22 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 22 | train_loss=0.5535
LIGHT eval | val ROC-AUC=0.7251 | pair-acc=0.7378
  poprawa! zapisano model (ROC-AUC=0.7251)


 Epoch 23 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 23 | train_loss=0.5518


 Epoch 24 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 24 | train_loss=0.5510
LIGHT eval | val ROC-AUC=0.7262 | pair-acc=0.7421
  poprawa! zapisano model (ROC-AUC=0.7262)


 Epoch 25 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 25 | train_loss=0.5509
HEAVY eval | @K=20: Recall@20=0.0172, MRR@20=0.0041


 Epoch 26 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 26 | train_loss=0.5513
LIGHT eval | val ROC-AUC=0.7277 | pair-acc=0.7429
  poprawa! zapisano model (ROC-AUC=0.7277)


 Epoch 27 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 27 | train_loss=0.5508


 Epoch 28 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 28 | train_loss=0.5484
LIGHT eval | val ROC-AUC=0.7251 | pair-acc=0.7408
  brak poprawy (1/5)


 Epoch 29 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 29 | train_loss=0.5501


 Epoch 30 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 30 | train_loss=0.5508
LIGHT eval | val ROC-AUC=0.7249 | pair-acc=0.7410
  brak poprawy (2/5)
HEAVY eval | @K=20: Recall@20=0.0191, MRR@20=0.0043


 Epoch 31 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 31 | train_loss=0.5498


 Epoch 32 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 32 | train_loss=0.5489
LIGHT eval | val ROC-AUC=0.7284 | pair-acc=0.7446
  poprawa! zapisano model (ROC-AUC=0.7284)


 Epoch 33 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 33 | train_loss=0.5507


 Epoch 34 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 34 | train_loss=0.5502
LIGHT eval | val ROC-AUC=0.7294 | pair-acc=0.7450
  poprawa! zapisano model (ROC-AUC=0.7294)


 Epoch 35 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 35 | train_loss=0.5483
HEAVY eval | @K=20: Recall@20=0.0153, MRR@20=0.0035


 Epoch 36 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 36 | train_loss=0.5491
LIGHT eval | val ROC-AUC=0.7300 | pair-acc=0.7444
  poprawa! zapisano model (ROC-AUC=0.7300)


 Epoch 37 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 37 | train_loss=0.5487


 Epoch 38 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 38 | train_loss=0.5485
LIGHT eval | val ROC-AUC=0.7276 | pair-acc=0.7438
  brak poprawy (1/5)


 Epoch 39 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 39 | train_loss=0.5485


 Epoch 40 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 40 | train_loss=0.5486
LIGHT eval | val ROC-AUC=0.7303 | pair-acc=0.7467
  poprawa! zapisano model (ROC-AUC=0.7303)
HEAVY eval | @K=20: Recall@20=0.0137, MRR@20=0.0031


 Epoch 41 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 41 | train_loss=0.5469


 Epoch 42 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 42 | train_loss=0.5475
LIGHT eval | val ROC-AUC=0.7270 | pair-acc=0.7428
  brak poprawy (1/5)


 Epoch 43 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 43 | train_loss=0.5475


 Epoch 44 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 44 | train_loss=0.5474
LIGHT eval | val ROC-AUC=0.7276 | pair-acc=0.7425
  brak poprawy (2/5)


 Epoch 45 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 45 | train_loss=0.5476
HEAVY eval | @K=20: Recall@20=0.0143, MRR@20=0.0034


 Epoch 46 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 46 | train_loss=0.5478
LIGHT eval | val ROC-AUC=0.7306 | pair-acc=0.7447
  poprawa! zapisano model (ROC-AUC=0.7306)


 Epoch 47 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 47 | train_loss=0.5474


 Epoch 48 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 48 | train_loss=0.5467
LIGHT eval | val ROC-AUC=0.7309 | pair-acc=0.7476
  poprawa! zapisano model (ROC-AUC=0.7309)


 Epoch 49 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 49 | train_loss=0.5472


 Epoch 50 batches:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 50 | train_loss=0.5472
LIGHT eval | val ROC-AUC=0.7297 | pair-acc=0.7449
  brak poprawy (1/5)
HEAVY eval | @K=20: Recall@20=0.0140, MRR@20=0.0033


  model.load_state_dict(torch.load(save_path))


TwoTowerModel(
  (user_tower): UserTower(
    (item_emb): Embedding(82911, 64)
    (rating_proj): Linear(in_features=2, out_features=1, bias=True)
    (mlp): Sequential(
      (0): Linear(in_features=89, out_features=512, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.2, inplace=False)
      (3): Linear(in_features=512, out_features=384, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.2, inplace=False)
      (6): Linear(in_features=384, out_features=256, bias=True)
      (7): ReLU()
      (8): Dropout(p=0.2, inplace=False)
      (9): Linear(in_features=256, out_features=64, bias=True)
    )
  )
  (item_tower): ItemTower(
    (actor_emb): Embedding(11465, 64)
    (director_emb): Embedding(5163, 64)
    (genre_emb): Embedding(20, 64)
    (meta_mlp): Sequential(
      (0): Linear(in_features=24, out_features=128, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.2, inplace=False)
      (3): Linear(in_features=128, out_features=64, bias=True)
      (4): ReLU()
    )
    (text_ml

In [550]:
torch.save(model.user_tower.state_dict(), 'user_tower.pth')
torch.save(model.item_tower.state_dict(), 'item_tower.pth')