In [208]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset

from torch.utils.data import Dataset, DataLoader

import pandas as pd
import numpy as np

from datetime import datetime
from tqdm import tqdm

import random

from sklearn.model_selection import train_test_split

In [209]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [210]:
df_users = pd.read_parquet('user_features_clean.parquet')
df_movies = pd.read_parquet('Movies_clean_Vec_v4_25keywords.parquet')
df_ratings = pd.read_parquet('ratings_groupped_ids.parquet')

# Przygotowanie movieId dla datasetów

In [211]:
print(df_users.info())
print(df_ratings.info())
print(df_movies.info())

empty_pos_ratings = df_ratings['pos'].apply(lambda x: len(x) == 0).sum()
empty_neg_ratings = df_ratings['neg'].apply(lambda x: len(x) == 0).sum()

if empty_pos_ratings != 0 or empty_neg_ratings != 0:
    print(f'Empty ratings: pos: {empty_pos_ratings}, neg: {empty_neg_ratings}')
    raise Exception("Users without a single pos/neg rating exist in the ratings_groupped_ids dataset")

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 198832 entries, 0 to 198831
Data columns (total 29 columns):
 #   Column                   Non-Null Count   Dtype  
---  ------                   --------------   -----  
 0   userId                   198832 non-null  int64  
 1   num_rating               198832 non-null  float64
 2   avg_rating               198832 non-null  float64
 3   weekend_watcher          198832 non-null  float64
 4   genre_Action             198832 non-null  float64
 5   genre_Adventure          198832 non-null  float64
 6   genre_Animation          198832 non-null  float64
 7   genre_Comedy             198832 non-null  float64
 8   genre_Crime              198832 non-null  float64
 9   genre_Documentary        198832 non-null  float64
 10  genre_Drama              198832 non-null  float64
 11  genre_Family             198832 non-null  float64
 12  genre_Fantasy            198832 non-null  float64
 13  genre_History            198832 non-null  float64
 14  genr

In [212]:
unique_ids = set(
        df_users['movies_seq'].explode().tolist()
        + df_ratings['pos'].explode().tolist() 
        + df_ratings['neg'].explode().tolist()
    )

print('Unique movieIds:', len(unique_ids))
unique_ids = sorted(unique_ids)

movieId_to_idx = {id_: idx for idx, id_ in enumerate(unique_ids)}
print('min idx:', min(movieId_to_idx.values()))
print('max idx:', max(movieId_to_idx.values()))

n_items = len(unique_ids)

assert min(movieId_to_idx.values()) == 0
assert max(movieId_to_idx.values()) == n_items - 1

# unique_ids = sorted(df_movies['movieId'].unique())
# movieId_to_idx = {id_: idx for idx, id_ in enumerate(unique_ids)}
# n_items = len(movieId_to_idx)


Unique movieIds: 82932
min idx: 0
max idx: 82931


In [213]:
# Zmapuj movieId do indeksów
df_users['movies_seq'] = df_users['movies_seq'].apply(lambda lst: [movieId_to_idx[m] for m in lst])
df_ratings['pos'] = df_ratings['pos'].apply(lambda lst: [movieId_to_idx[m] for m in lst])
df_ratings['neg'] = df_ratings['neg'].apply(lambda lst: [movieId_to_idx[m] for m in lst])

# df_movies musi być ograniczone tylko do używanych filmów
df_movies = df_movies[df_movies['movieId'].isin(movieId_to_idx)]
df_movies['movie_idx'] = df_movies['movieId'].map(movieId_to_idx)

# Final sanity check
assert df_users['movies_seq'].explode().max() < n_items
assert df_ratings['pos'].explode().max() < n_items
assert df_ratings['neg'].explode().max() < n_items
assert df_movies['movie_idx'].max() < n_items
assert df_movies['movie_idx'].notna().all(), "Some movieIds weren't mapped!"


# df_users['movies_seq'] = df_users['movies_seq'].apply(lambda lst: [movieId_to_idx[m] for m in lst if m in movieId_to_idx])
# df_ratings['pos'] = df_ratings['pos'].apply(lambda lst: [movieId_to_idx[m] for m in lst if m in movieId_to_idx])
# df_ratings['neg'] = df_ratings['neg'].apply(lambda lst: [movieId_to_idx[m] for m in lst if m in movieId_to_idx])
# 
# assert df_users['movies_seq'].explode().max() < n_items

In [214]:
max_movie_idx = df_users['movies_seq'].explode().max()
print("max_movie_idx =", max_movie_idx)
print("n_items =", n_items)

assert max_movie_idx < n_items, "Indeks filmu przekracza rozmiar embeddingu"

max_movie_idx = 82931
n_items = 82932


In [215]:
def has_invalid_entries(seq_col):
    return seq_col.explode().isin([-1, np.nan, None]).any()

print("Zawiera niepoprawne wartości:", has_invalid_entries(df_users['movies_seq']))


Zawiera niepoprawne wartości: False


In [216]:
df_movies.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 82918 entries, 0 to 82917
Data columns (total 30 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   movieId              82918 non-null  int64  
 1   runtime              82918 non-null  float64
 2   if_blockbuster       82918 non-null  int64  
 3   highly_watched       82918 non-null  int64  
 4   highly_rated         82918 non-null  int64  
 5   engagement_score     82918 non-null  float64
 6   cast_importance      82918 non-null  float64
 7   director_score       82918 non-null  float64
 8   has_keywords         82918 non-null  int64  
 9   has_cast             82918 non-null  int64  
 10  has_director         82918 non-null  int64  
 11  genre_ids            82918 non-null  object 
 12  decade_[1890, 1900)  82918 non-null  bool   
 13  decade_[1900, 1910)  82918 non-null  bool   
 14  decade_[1910, 1920)  82918 non-null  bool   
 15  decade_[1920, 1930)  82918 non-null 

In [217]:
#FOR QUICK TEST's

DEBUG = False

if DEBUG:
    df_users = df_users.sample(n=256, random_state=42).copy()
    df_ratings = df_ratings[df_ratings['userId'].isin(df_users['userId'])].copy()


# Przygotowanie danych (Item Tower)

In [218]:
def prepare_feature_tensor(df_movies: pd.DataFrame):
    import ast

    for col in ['text_embedded', 'genre_ids', 'actor_ids', 'director_ids']:
        df_movies[col] = df_movies[col].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)

    # Mapowania ID dla Embeedings
    all_actor_ids = set(i for sub in df_movies['actor_ids'] for i in sub)
    all_director_ids = set(i for sub in df_movies['director_ids'] for i in sub)
    all_genre_ids = set(i for sub in df_movies['genre_ids'] for i in sub)

    actor_id_map = {aid: idx for idx, aid in enumerate(sorted(all_actor_ids))}
    director_id_map = {did: idx for idx, did in enumerate(sorted(all_director_ids))}
    genre_id_map = {gid: idx for idx, gid in enumerate(sorted(all_genre_ids))}
    
    # Map the raw IDs to internal indices
    df_movies['actor_ids'] = df_movies['actor_ids'].apply(lambda lst: [actor_id_map[i] for i in lst])
    df_movies['director_ids'] = df_movies['director_ids'].apply(lambda lst: [director_id_map[i] for i in lst])
    df_movies['genre_ids'] = df_movies['genre_ids'].apply(lambda lst: [genre_id_map[i] for i in lst])

    # EmbeddingBag
    def make_bag_inputs(id_lists):
        flat = []
        offsets = [0]
        for lst in id_lists:
            flat.extend(lst)
            offsets.append(len(flat))
        return torch.tensor(flat, dtype=torch.long), torch.tensor(offsets[:-1], dtype=torch.long)

    actor_idx_bag, actor_offsets = make_bag_inputs(df_movies['actor_ids'])
    director_idx_bag, director_offsets = make_bag_inputs(df_movies['director_ids'])
    genre_idx_bag, genre_offsets = make_bag_inputs(df_movies['genre_ids'])

    text_tensor = np.stack(df_movies['text_embedded'].apply(np.array).to_list())

    numeric_cols = ['runtime', 'engagement_score', 'cast_importance', 'director_score']
    binary_cols = ['if_blockbuster', 'highly_watched', 'highly_rated', 'has_keywords', 'has_cast', 'has_director']
    decade_cols = [col for col in df_movies.columns if col.startswith("decade_")]

    num_bin_tensor = df_movies[numeric_cols + binary_cols + decade_cols].astype(np.float32).values
    full_features = np.hstack([num_bin_tensor, text_tensor])
    features_tensor = torch.tensor(full_features, dtype=torch.float32)

    if torch.isnan(features_tensor).any():
        print("NaN in feature tensor!")
        features_tensor = torch.nan_to_num(features_tensor)

    # assert actor_idx_bag.max().item() < num_actors, "Actor ID exceeds num_actors"
    
    return (features_tensor,
            actor_idx_bag, actor_offsets,
            director_idx_bag, director_offsets,
            genre_idx_bag, genre_offsets,
            len(actor_id_map), len(director_id_map), len(genre_id_map))

In [219]:
def get_embedding_bag_inputs(indices, bag_tensor, offset_tensor):
    new_offsets = []
    new_bag = []
    offset = 0
    for i in indices:
        i = i.item()
        start = offset_tensor[i].item()
        end = offset_tensor[i + 1].item() if i + 1 < len(offset_tensor) else len(bag_tensor)
        segment = bag_tensor[start:end]
        new_bag.extend(segment.tolist())
        new_offsets.append(offset)
        offset += len(segment)
    return torch.tensor(new_bag, dtype=torch.long), torch.tensor(new_offsets, dtype=torch.long)

# Przygotowanie danych (User Tower)

In [220]:
class UserDataset(Dataset):
    def __init__(self, df_users):
        self.data = df_users
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data.iloc[idx]

In [221]:
n_items = len(unique_ids)

def collate_fn(batch):
    df_users, movies, ratings, timestamps, pos, neg = [], [], [], [], [], []

    for row in batch:
        movies.append(torch.tensor(row['movies_seq'], dtype=torch.long))
        ratings.append(torch.tensor(row['ratings_seq'], dtype=torch.float32))
        timestamps.append(torch.tensor(row['ts_seq'], dtype=torch.float32))

        userId = row['userId']

        r = row[['num_rating', 'avg_rating', 'weekend_watcher', 'genre_Action', 'genre_Adventure', 'genre_Animation', 'genre_Comedy', 'genre_Crime', 'genre_Documentary', 'genre_Drama', 'genre_Family', 'genre_Fantasy', 'genre_History', 'genre_Horror', 'genre_Music', 'genre_Mystery', 'genre_Romance', 'genre_Science Fiction', 'genre_TV Movie', 'genre_Thriller', 'genre_War', 'genre_Western', 'type_of_viewer_negative', 'type_of_viewer_neutral', 'type_of_viewer_positive']]
        r = r.astype('float32').values

        df_users.append(torch.tensor(r, dtype=torch.float32))
        
        # Get a random movieId that was rated positively and one that was rated negatively. 
        # Used during training to calculate BPR loss. 
        posAndNegRow = df_ratings[df_ratings['userId'] == userId].iloc[0]
        pos.append(torch.tensor(random.choice(posAndNegRow['pos']), dtype=torch.long))
        neg.append(torch.tensor(random.choice(posAndNegRow['neg']), dtype=torch.long))

    return {
        "input": {
            "df_users": torch.stack(df_users),
            "movies": torch.stack(movies),
            "ratings": torch.stack(ratings),
            "timestamps": torch.stack(timestamps),
        },
        "pos": torch.as_tensor(pos, dtype=torch.long),
        "neg": torch.as_tensor(neg, dtype=torch.long)
    }

# Model (Item Tower)

In [222]:
class ItemTower(nn.Module):
    def __init__(self, input_dim, embedding_dim=64, num_actors=10000, num_directors=5000, num_genres=19):
        super(ItemTower, self).__init__()
        self.actor_embedding = nn.EmbeddingBag(num_actors, 32, mode='mean')
        self.director_embedding = nn.EmbeddingBag(num_directors, 32, mode='mean')
        self.genre_embedding = nn.EmbeddingBag(num_genres, 16, mode='mean')

        self.model = nn.Sequential(
            nn.Linear(input_dim + 32 + 32 + 16, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, embedding_dim)
        )

    def forward(self, x, actor_bag, actor_offsets,
                      director_bag, director_offsets,
                      genre_bag, genre_offsets):
        actor_emb = self.actor_embedding(actor_bag, actor_offsets)
        director_emb = self.director_embedding(director_bag, director_offsets)
        genre_emb = self.genre_embedding(genre_bag, genre_offsets)

        x = torch.cat([x, actor_emb, director_emb, genre_emb], dim=1)
        return self.model(x)

# Model (User Tower)

In [223]:
class UserTower(nn.Module):
    def __init__(self, input_dim, embedding_dim=64, n_items=1000):
        super(UserTower, self).__init__()

        # Item Embeddings for User History
        self.item_emb = nn.Embedding(n_items, embedding_dim)
        
        # A layer to project rating and timestamp into a scalar weight
        self.rating_proj = nn.Linear(2, 1)

        self.mlp = nn.Sequential(
            nn.Linear(input_dim + embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 384),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(384, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, embedding_dim)
        )
    
    def forward(self, batch):
        # Embed movieIds liked by user
        m = self.item_emb(batch['movies'])

        # Get weights from rating and timestamp
        x = torch.stack([batch['ratings'], batch['timestamps']], dim=-1)
        w = torch.sigmoid(self.rating_proj(x))

        # weighted mean-pool
        pooled = (m * w).sum(1) / (w.sum(1).clamp_min(1e-6))

        input = torch.cat([batch['df_users'], pooled], dim=-1)
        output = self.mlp(input)
        u = F.normalize(output, dim=1)
        return u

# Evaluation

In [224]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

def evaluate_two_tower_model_batched(
    user_tower, item_tower,
    df_users, df_ratings, df_movies,
    movie_id_to_idx,
    movie_features, actor_idx_bag, actor_offsets,
    director_idx_bag, director_offsets,
    genre_idx_bag, genre_offsets,
    top_k=10, max_users=500, batch_size=32,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
):

    user_tower.eval()
    item_tower.eval()

    # PRECOMPUTED
    with torch.no_grad():
        item_embeddings = item_tower(
            movie_features.to(device),
            actor_idx_bag.to(device), actor_offsets.to(device),
            director_idx_bag.to(device), director_offsets.to(device),
            genre_idx_bag.to(device), genre_offsets.to(device)
        )
        item_embeddings = F.normalize(item_embeddings, dim=1)

    print("Evaluating users in batches...")
    user_ids = df_users['userId'].unique()
    if len(user_ids) > max_users:
        user_ids = np.random.choice(user_ids, size=max_users, replace=False)

    metrics = {'Precision@K': [], 'Recall@K': [], 'MRR': [], 'nDCG@K': []}

    def precision_at_k(true_item, recommended, k): return int(true_item in recommended[:k]) / k
    
    def recall_at_k(true_item, recommended, k): return int(true_item in recommended[:k])
    
    def mrr(true_item, recommended): return 1 / (recommended.index(true_item) + 1) if true_item in recommended else 0
    
    def ndcg_at_k(true_item, recommended, k):
        if true_item in recommended[:k]:
            rank = recommended.index(true_item)
            return 1 / np.log2(rank + 2)
        return 0.0

    for i in tqdm(range(0, len(user_ids), batch_size)):
        batch_user_ids = user_ids[i:i+batch_size]
        batch_inputs = {'movies': [], 'ratings': [], 'timestamps': [], 'df_users': []}
        held_out_items = []
        history_indices = []

        for user_id in batch_user_ids:
            row_u = df_users[df_users['userId'] == user_id].iloc[0]
            row_r = df_ratings[df_ratings['userId'] == user_id].iloc[0]
            pos_movies = row_r['pos']
            if len(pos_movies) < 2:
                continue

            held_out = pos_movies[-1]
            history = pos_movies[:-1]

            indices = [movie_id_to_idx[mid] for mid in history if mid in movie_id_to_idx]
            if not indices:
                continue

            batch_inputs['movies'].append(torch.tensor(indices, dtype=torch.long))
            batch_inputs['ratings'].append(torch.tensor([4.0]*len(indices), dtype=torch.float32))
            batch_inputs['timestamps'].append(torch.tensor([1.0]*len(indices), dtype=torch.float32))
            batch_inputs['df_users'].append(torch.tensor(row_u.drop(['userId', 'movies_seq', 'ratings_seq', 'ts_seq']).values.astype(np.float32)))

            held_out_items.append(held_out)
            history_indices.append(indices)

        if not held_out_items:
            continue

        max_len = max(len(seq) for seq in batch_inputs['movies'])
        for key in ['movies', 'ratings', 'timestamps']:
            batch_inputs[key] = torch.stack([
                F.pad(seq, (0, max_len - len(seq)), value=0) for seq in batch_inputs[key]
            ])

        batch_inputs['df_users'] = torch.stack(batch_inputs['df_users']).to(device)
        batch_inputs['movies'] = batch_inputs['movies'].to(device)
        batch_inputs['ratings'] = batch_inputs['ratings'].to(device)
        batch_inputs['timestamps'] = batch_inputs['timestamps'].to(device)

        with torch.no_grad():
            user_vecs = user_tower(batch_inputs)
            scores = torch.matmul(user_vecs, item_embeddings.T)

            for j, user_score in enumerate(scores):
                # Mask history
                user_score[history_indices[j]] = -1e9
                top_k_items = torch.topk(user_score, k=top_k).indices.tolist()
                top_k_movie_ids = [df_movies.iloc[x]['movieId'] for x in top_k_items]

                true_item = held_out_items[j]
                metrics['Precision@K'].append(precision_at_k(true_item, top_k_movie_ids, top_k))
                metrics['Recall@K'].append(recall_at_k(true_item, top_k_movie_ids, top_k))
                metrics['MRR'].append(mrr(true_item, top_k_movie_ids))
                metrics['nDCG@K'].append(ndcg_at_k(true_item, top_k_movie_ids, top_k))

    return {k: np.mean(v) if v else 0.0 for k, v in metrics.items()}

# TRAINING

In [225]:
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.mps.is_available():
    device = torch.device('mps')

print('Device:', device)

Device: cuda


In [226]:
def to_device(data, device):
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    elif torch.is_tensor(data):
        return data.to(device)
    else:
        return data

In [227]:
BATCH_SIZE = 4096 #DEBUG: 32

from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(df_users, test_size=0.2)

movie_features, actor_idx_bag, actor_offsets, director_idx_bag, director_offsets, genre_idx_bag, genre_offsets, num_actors, num_directors, num_genres = prepare_feature_tensor(df_movies)

trainDataset = UserDataset(train_df)
trainDataLoader = DataLoader(trainDataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

testDataset = UserDataset(test_df)
testDataLoader = DataLoader(testDataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

In [228]:
EMB_DIM = 128

user_tower = UserTower(input_dim=25, n_items=n_items, embedding_dim=EMB_DIM).to(device)
item_tower = ItemTower(
    input_dim=movie_features.shape[1],
    embedding_dim=EMB_DIM,
    num_actors=num_actors,
    num_directors=num_directors,
    num_genres=num_genres
).to(device)


params = list(user_tower.parameters()) + list(item_tower.parameters())
optimizer = optim.Adam(params, lr=1e-3)

In [229]:
def train_one_epoch_two_tower(user_tower, item_tower, data_loader, optimizer, device, movie_features,
                              actor_idx_bag, actor_offsets,
                              director_idx_bag, director_offsets,
                              genre_idx_bag, genre_offsets):
    
    user_tower.train()
    item_tower.train()
    running_loss = 0.0
    total = 0
    
    movie_features = movie_features.to(device)
    actor_idx_bag = actor_idx_bag.to(device)
    actor_offsets = actor_offsets.to(device)
    director_idx_bag = director_idx_bag.to(device)
    director_offsets = director_offsets.to(device)
    genre_idx_bag = genre_idx_bag.to(device)
    genre_offsets = genre_offsets.to(device)

    for batch in data_loader:
        batch = to_device(batch, device)
        optimizer.zero_grad()

        user_vec = user_tower(batch['input'])

        actor_pos_bag, actor_pos_offsets = get_embedding_bag_inputs(batch['pos'], actor_idx_bag, actor_offsets)
        director_pos_bag, director_pos_offsets = get_embedding_bag_inputs(batch['pos'], director_idx_bag, director_offsets)
        genre_pos_bag, genre_pos_offsets = get_embedding_bag_inputs(batch['pos'], genre_idx_bag, genre_offsets)

        actor_neg_bag, actor_neg_offsets = get_embedding_bag_inputs(batch['neg'], actor_idx_bag, actor_offsets)
        director_neg_bag, director_neg_offsets = get_embedding_bag_inputs(batch['neg'], director_idx_bag, director_offsets)
        genre_neg_bag, genre_neg_offsets = get_embedding_bag_inputs(batch['neg'], genre_idx_bag, genre_offsets)
        
        #FOR DEBUGGING
        # print("max actor id in batch:", actor_pos_bag.max().item(), "num_actors:", item_tower.actor_embedding.num_embeddings)

        pos_vec = item_tower(movie_features[batch['pos']].to(device), actor_pos_bag.to(device), actor_pos_offsets.to(device),
                             director_pos_bag.to(device), director_pos_offsets.to(device),
                             genre_pos_bag.to(device), genre_pos_offsets.to(device))
        
        neg_vec = item_tower(movie_features[batch['neg']].to(device), actor_neg_bag.to(device), actor_neg_offsets.to(device),
                             director_neg_bag.to(device), director_neg_offsets.to(device),
                             genre_neg_bag.to(device), genre_neg_offsets.to(device))


        pos_score = (user_vec * pos_vec).sum(dim=-1)
        neg_score = (user_vec * neg_vec).sum(dim=-1)

        loss = -F.logsigmoid(pos_score - neg_score).mean()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total += 1

    return running_loss / total

In [None]:
from tqdm import tqdm
from datetime import datetime
from sklearn.metrics import roc_auc_score

EPOCHS = 50
EVAL_EVERY = 5
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

for epoch in tqdm(range(EPOCHS), desc="Training Two-Tower"):
    avg_loss = train_one_epoch_two_tower(
        user_tower=user_tower,
        item_tower=item_tower,
        data_loader=trainDataLoader,
        optimizer=optimizer,
        device=device,
        movie_features = movie_features,
        actor_idx_bag=actor_idx_bag,
        actor_offsets=actor_offsets,
        director_idx_bag=director_idx_bag,
        director_offsets=director_offsets,
        genre_idx_bag=genre_idx_bag,
        genre_offsets=genre_offsets
    )
    
    print(f"[Epoch {epoch + 1}] | Loss: {avg_loss:.4f}")

    if epoch % EVAL_EVERY == (EVAL_EVERY - 1):
        user_tower.eval()
        item_tower.eval()

        aucs, pair_accs = [], []
        
        movie_features = movie_features.to(device)
        actor_idx_bag = actor_idx_bag.to(device)
        actor_offsets = actor_offsets.to(device)
        director_idx_bag = director_idx_bag.to(device)
        director_offsets = director_offsets.to(device)
        genre_idx_bag = genre_idx_bag.to(device)
        genre_offsets = genre_offsets.to(device)

        with torch.no_grad():
            item_emb = item_tower(
                movie_features.to(device),
                actor_idx_bag.to(device),
                actor_offsets.to(device),
                director_idx_bag.to(device),
                director_offsets.to(device),
                genre_idx_bag.to(device),
                genre_offsets.to(device)
            ).cpu().detach().numpy()

            for batch in testDataLoader:
                batch = to_device(batch, device)

                u = user_tower(batch['input'])

                actor_pos_bag, actor_pos_offsets = get_embedding_bag_inputs(batch['pos'], actor_idx_bag, actor_offsets)
                director_pos_bag, director_pos_offsets = get_embedding_bag_inputs(batch['pos'], director_idx_bag,director_offsets)
                genre_pos_bag, genre_pos_offsets = get_embedding_bag_inputs(batch['pos'], genre_idx_bag,genre_offsets)

                actor_neg_bag, actor_neg_offsets = get_embedding_bag_inputs(batch['neg'], actor_idx_bag,actor_offsets)
                director_neg_bag, director_neg_offsets = get_embedding_bag_inputs(batch['neg'], director_idx_bag, director_offsets)
                genre_neg_bag, genre_neg_offsets = get_embedding_bag_inputs(batch['neg'], genre_idx_bag, genre_offsets)

                pos_vec = item_tower(movie_features[batch['pos']].to(device), actor_pos_bag.to(device), actor_pos_offsets.to(device),
                                     director_pos_bag.to(device), director_pos_offsets.to(device),
                                     genre_pos_bag.to(device), genre_pos_offsets.to(device))
                
                neg_vec = item_tower(movie_features[batch['neg']].to(device), actor_neg_bag.to(device), actor_neg_offsets.to(device),
                                     director_neg_bag.to(device), director_neg_offsets.to(device),
                                     genre_neg_bag.to(device), genre_neg_offsets.to(device))

                pos_score = (u * pos_vec).sum(dim=-1)
                neg_score = (u * neg_vec).sum(dim=-1)

                labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
                scores = torch.cat([pos_score, neg_score])

                aucs.append(roc_auc_score(labels.cpu(), scores.cpu()))
                pair_accs.append((pos_score > neg_score).float().mean().item())

        print(f"[Epoch {epoch + 1}] Pointwise Eval:")
        print(f"  ROC AUC:       {np.mean(aucs):.4f}")
        print(f"  Pairwise Acc:  {np.mean(pair_accs):.4f}")

        rank_metrics = evaluate_two_tower_model_batched(
            user_tower=user_tower,
            item_tower=item_tower,
            df_users=test_df,           # ← has user profiles & sequences
            df_ratings=df_ratings,      # ← has pos/neg lists
            df_movies=df_movies,        # ← all movie features
            movie_id_to_idx = df_movies.set_index('movieId')['movie_idx'].to_dict(),
            top_k=10,
            max_users=1000, #DEBUG: 32
            batch_size=32, #DEBUG: 8
            device=device,
            movie_features=movie_features,
            actor_idx_bag=actor_idx_bag,
            actor_offsets=actor_offsets,
            director_idx_bag=director_idx_bag,
            director_offsets=director_offsets,
            genre_idx_bag=genre_idx_bag,
            genre_offsets=genre_offsets
        )

        print(f"[Epoch {epoch + 1}] Retrieval Eval:")
        print(f"  Precision@K:   {rank_metrics['Precision@K']:.4f}")
        print(f"  Recall@K:      {rank_metrics['Recall@K']:.4f}")
        print(f"  MRR:           {rank_metrics['MRR']:.4f}")
        print(f"  nDCG@K:        {rank_metrics['nDCG@K']:.4f}")

Training Two-Tower:   2%|▏         | 1/50 [06:50<5:35:00, 410.22s/it]

[Epoch 1] | Loss: 0.5869


Training Two-Tower:   4%|▍         | 2/50 [13:35<5:26:03, 407.56s/it]

[Epoch 2] | Loss: 0.2913


Training Two-Tower:   6%|▌         | 3/50 [20:14<5:16:05, 403.51s/it]

[Epoch 3] | Loss: 0.1359


Training Two-Tower:   8%|▊         | 4/50 [26:52<5:07:45, 401.41s/it]

[Epoch 4] | Loss: 0.0778
[Epoch 5] | Loss: 0.0453
[Epoch 5] Pointwise Eval:
  ROC AUC:       0.6011
  Pairwise Acc:  0.6015
Evaluating users in batches...



  0%|          | 0/32 [00:00<?, ?it/s][A
  3%|▎         | 1/32 [00:00<00:07,  3.94it/s][A
  6%|▋         | 2/32 [00:00<00:07,  3.82it/s][A
  9%|▉         | 3/32 [00:00<00:08,  3.41it/s][A
 12%|█▎        | 4/32 [00:01<00:08,  3.47it/s][A
 16%|█▌        | 5/32 [00:01<00:07,  3.84it/s][A
 19%|█▉        | 6/32 [00:01<00:06,  4.12it/s][A
 22%|██▏       | 7/32 [00:01<00:05,  4.18it/s][A
 25%|██▌       | 8/32 [00:02<00:05,  4.05it/s][A
 28%|██▊       | 9/32 [00:02<00:05,  3.90it/s][A
 31%|███▏      | 10/32 [00:02<00:05,  4.03it/s][A
 34%|███▍      | 11/32 [00:02<00:04,  4.54it/s][A
 38%|███▊      | 12/32 [00:02<00:03,  5.27it/s][A
 41%|████      | 13/32 [00:02<00:03,  5.98it/s][A
 44%|████▍     | 14/32 [00:03<00:03,  5.84it/s][A
 47%|████▋     | 15/32 [00:03<00:02,  5.69it/s][A
 50%|█████     | 16/32 [00:03<00:02,  5.85it/s][A
 53%|█████▎    | 17/32 [00:03<00:02,  6.34it/s][A
 56%|█████▋    | 18/32 [00:03<00:02,  6.76it/s][A
 59%|█████▉    | 19/32 [00:03<00:01,  7.07it/s]

[Epoch 5] Retrieval Eval:
  Precision@K:   0.0000
  Recall@K:      0.0000
  MRR:           0.0000
  nDCG@K:        0.0000


In [None]:
torch.save(user_tower.state_dict(), f'user_tower_{timestamp}.pt')
torch.save(item_tower.state_dict(), f'item_tower_{timestamp}.pt')