In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

import pandas as pd

from datetime import datetime
from tqdm import tqdm

import random

In [None]:
# Load the preprocessed inputs: per-user features and sequences, and, per user,
# the movies they rated positively ('pos') and negatively ('neg').
user_features, ratings_groupped_ids = (
    pd.read_parquet(parquet_path)
    for parquet_path in (
        '../datasets/user_features_clean.parquet',
        '../datasets/ratings_groupped_ids.parquet',
    )
)

In [None]:
# DataFrame.info() prints its report itself and returns None, so it is called
# directly (print(df.info()) additionally emitted a stray "None").
user_features.info()
ratings_groupped_ids.info()

# collate() draws one positive and one negative movie per user with
# random.choice, which raises on an empty list, so reject such users up front.
empty_pos_ratings = ratings_groupped_ids['pos'].map(len).eq(0).sum()
empty_neg_ratings = ratings_groupped_ids['neg'].map(len).eq(0).sum()

if empty_pos_ratings != 0 or empty_neg_ratings != 0:
    print(f'Empty ratings: pos: {empty_pos_ratings}, neg: {empty_neg_ratings}')
    raise ValueError("Users without a single pos/neg rating exist in the ratings_groupped_ids dataset")

# collate() also looks every user_features row up by userId in
# ratings_groupped_ids; a missing user would otherwise only surface as a crash
# partway through the first training epoch.
missing_users = set(user_features['userId']) - set(ratings_groupped_ids['userId'])
if missing_users:
    raise ValueError(
        f'{len(missing_users)} userIds in user_features have no entry in ratings_groupped_ids'
    )

# Map movieIds to a contiguous range of natural numbers (0..n_items-1) so they can be used as nn.Embedding indices

In [None]:
# Gather every movieId the model will ever look up (watch history, positive
# and negative ratings) so a single embedding table covers all of them.
# NOTE(review): pandas explodes an empty list to NaN, which would then enter
# the vocabulary as an extra (unused) id.
unique_ids = set()
for id_lists in (
    user_features['movies_seq'],
    ratings_groupped_ids['pos'],
    ratings_groupped_ids['neg'],
):
    unique_ids.update(id_lists.explode().tolist())

print('Unique movieIds:', len(unique_ids))
unique_ids = sorted(unique_ids)
n_items = len(unique_ids)

# Dense, order-preserving relabelling: the k-th smallest movieId becomes k.
movieId_to_idx = dict(zip(unique_ids, range(n_items)))
print('min idx:', min(movieId_to_idx.values()))
print('max idx:', max(movieId_to_idx.values()))

assert min(movieId_to_idx.values()) == 0
assert max(movieId_to_idx.values()) == n_items - 1

In [None]:
# Convert movieIds in user_features / ratings_groupped_ids to the indices accepted by nn.Embedding
def map_list(col):
    """Translate one list of raw movieIds into contiguous embedding indices.

    Raises:
        KeyError: if a movieId is missing from movieId_to_idx.
    """
    return [movieId_to_idx[m] for m in col]


# The columns are overwritten in place, so applying the mapping twice is wrong:
# if raw movieIds overlap the index range 0..n_items-1, a second pass would
# silently relabel indices as if they were movieIds instead of failing.
# Record which columns were converted in each frame's `attrs` and skip them on
# a re-run; a freshly reloaded frame has no record and is mapped again.
for df, col in [
    (user_features, 'movies_seq'),
    (ratings_groupped_ids, 'pos'),
    (ratings_groupped_ids, 'neg'),
]:
    mapped_cols = df.attrs.setdefault('movie_id_mapped_cols', [])
    if col in mapped_cols:
        print(f'{col}: already mapped to embedding indices, skipping')
        continue
    df[col] = df[col].apply(map_list)
    mapped_cols.append(col)


# Every id fed to nn.Embedding (history, pos and neg) must be a valid row index.
max_idx = max(movieId_to_idx.values())
for df, col in [
    (user_features, 'movies_seq'),
    (ratings_groupped_ids, 'pos'),
    (ratings_groupped_ids, 'neg'),
]:
    assert all(0 <= id_ <= max_idx for l in df[col] for id_ in l), f'out-of-range index in {col}'

In [None]:
class UserDataset(Dataset):
    """Map-style dataset that yields one row of a user-features DataFrame per index."""

    def __init__(self, user_features):
        # Hold a reference to the frame (no copy); rows are extracted on demand.
        self.data = user_features

    def __len__(self):
        n_rows = self.data.shape[0]
        return n_rows

    def __getitem__(self, idx):
        # Positional lookup; the result is a pandas Series keyed by column name.
        row = self.data.iloc[idx]
        return row



In [None]:
EMB_DIM = 64

class UserTower(nn.Module):
    """User side of a two-tower recommender.

    Pools the embeddings of the user's movie history (weighted by a learned
    function of rating and timestamp), concatenates the pooled vector with the
    dense user features, and maps the result through an MLP to an L2-normalised
    user vector of size `embedding_dim`.
    """

    def __init__(self, input_dim, n_items, embedding_dim=EMB_DIM):
        '''
        input_dim - the number of columns in user features, without sequence columns
        n_items - size of the movie vocabulary (rows of the item embedding table)
        embedding_dim - size of item embeddings and of the output user vector
        '''
        super().__init__()

        # Item table; also read directly by the training loop for pos/neg items.
        self.item_emb = nn.Embedding(n_items, embedding_dim)

        # Maps (rating, timestamp) of each history entry to one scalar weight.
        self.rating_proj = nn.Linear(2, 1)

        # Linear -> ReLU -> Dropout per hidden layer, then a final projection.
        # Layers are created in the same order (and at the same Sequential
        # indices) as a hand-written nn.Sequential would use.
        layers = []
        in_features = input_dim + embedding_dim
        for out_features in (512, 384, 256):
            layers += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(0.2)]
            in_features = out_features
        layers.append(nn.Linear(in_features, embedding_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, batch):
        """Return unit-length user vectors, one per row of the batch.

        `batch` needs the keys 'movies', 'ratings', 'timestamps' (one value per
        history position) and 'user_features'. Every position in 'movies'
        contributes to the pool; there is no padding mask.
        NOTE(review): timestamps enter rating_proj unscaled — confirm they are
        normalised upstream, otherwise the sigmoid weights may saturate.
        """
        history = self.item_emb(batch['movies'])

        rating_time = torch.stack((batch['ratings'], batch['timestamps']), dim=-1)
        weights = self.rating_proj(rating_time).sigmoid()

        # Weighted mean over the history axis; clamp guards against dividing by ~0.
        weighted_sum = (history * weights).sum(dim=1)
        total_weight = weights.sum(dim=1).clamp_min(1e-6)
        pooled = weighted_sum / total_weight

        features = torch.cat((batch['user_features'], pooled), dim=-1)
        return F.normalize(self.mlp(features), dim=1)


In [None]:
# Dense per-user features fed to UserTower, in a fixed order. Its length (25)
# must match the `input_dim` the model is constructed with.
USER_FEATURE_COLS = [
    'num_rating', 'avg_rating', 'weekend_watcher',
    'genre_Action', 'genre_Adventure', 'genre_Animation', 'genre_Comedy', 'genre_Crime',
    'genre_Documentary', 'genre_Drama', 'genre_Family', 'genre_Fantasy', 'genre_History',
    'genre_Horror', 'genre_Music', 'genre_Mystery', 'genre_Romance', 'genre_Science Fiction',
    'genre_TV Movie', 'genre_Thriller', 'genre_War', 'genre_Western',
    'type_of_viewer_negative', 'type_of_viewer_neutral', 'type_of_viewer_positive',
]

# userId -> (pos movie indices, neg movie indices), built once when this cell
# runs instead of scanning all of ratings_groupped_ids for every sample.
# The first row per userId is kept, matching a `df[df.userId == u].iloc[0]` lookup.
# Re-run this cell if ratings_groupped_ids is reloaded.
_pos_neg_rows = ratings_groupped_ids.drop_duplicates(subset='userId', keep='first')
POS_NEG_BY_USER = dict(
    zip(_pos_neg_rows['userId'], zip(_pos_neg_rows['pos'], _pos_neg_rows['neg']))
)
del _pos_neg_rows


def collate(batch):
    """Assemble a list of user rows into one training batch.

    Args:
        batch: list of rows (pandas Series) as returned by UserDataset.__getitem__.

    Returns:
        dict with
          "input": the tensors UserTower.forward reads — "user_features"
                   (float32, USER_FEATURE_COLS per user), "movies" (long),
                   "ratings" and "timestamps" (float32), each stacked over users;
          "pos" / "neg": long tensors holding, per user, one randomly chosen
                   positively / negatively rated movie index (for the BPR loss).
        NOTE(review): torch.stack assumes every user's sequences have the same
        length — confirm the sequences are padded/truncated upstream.

    Raises:
        KeyError: if a row's userId has no entry in ratings_groupped_ids.
    """
    features, movies, ratings, timestamps, pos, neg = [], [], [], [], [], []

    for row in batch:
        movies.append(torch.tensor(row['movies_seq'], dtype=torch.long))
        ratings.append(torch.tensor(row['ratings_seq'], dtype=torch.float32))
        timestamps.append(torch.tensor(row['ts_seq'], dtype=torch.float32))
        features.append(
            torch.tensor(row[USER_FEATURE_COLS].to_numpy(dtype='float32'), dtype=torch.float32)
        )

        # Sample one positive and one negative movie for the BPR loss.
        user_id = row['userId']
        try:
            pos_ids, neg_ids = POS_NEG_BY_USER[user_id]
        except KeyError:
            raise KeyError(f'userId {user_id!r} has no entry in ratings_groupped_ids') from None
        pos.append(random.choice(pos_ids))
        neg.append(random.choice(neg_ids))

    return {
        "input": {
            "user_features": torch.stack(features),
            "movies": torch.stack(movies),
            "ratings": torch.stack(ratings),
            "timestamps": torch.stack(timestamps),
        },
        "pos": torch.tensor(pos, dtype=torch.long),
        "neg": torch.tensor(neg, dtype=torch.long),
    }

In [None]:
# Single source of truth for the batch size. The DataLoader was effectively
# running with 256; the previously declared value (4096) was never used.
BATCH_SIZE = 256

dataset = UserDataset(user_features)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)

# Prefer CUDA, then Apple-silicon MPS, otherwise CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print('Device:', device)

In [None]:
# Commented out because it takes ~3 minutes to run and the check currently passes

# # Verify the data 
# def check_ids(tensor, column):
#     if (tensor < 0).any() or (tensor >= n_items).any():
#         raise ValueError(f"Out of range index in column {column}. Value: {tensor[(tensor<0) | (tensor >= n_items)]}")


# for row in tqdm(dataloader):
#     check_ids(row['input']['movies'], 'input.movies')
#     check_ids(row['pos'], 'pos')
#     check_ids(row['neg'], 'neg')

In [None]:
import torch.optim as optim

# Seed every RNG that training depends on: torch's global RNG (parameter
# initialisation and, since the DataLoader has no explicit generator, its
# shuffling) and Python's `random` (pos/neg sampling in collate()).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# input_dim must equal the number of dense user-feature columns collate() selects (25).
model = UserTower(input_dim=25, n_items=len(unique_ids)).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def bpr_loss(pos_score, neg_score):
    """Bayesian Personalized Ranking loss: mean of -log(sigmoid(pos_score - neg_score))."""
    return -F.logsigmoid(pos_score - neg_score).mean()


# The model is trained with BPR (as in train_one_epoch), not MSE.
loss_fn = bpr_loss

In [None]:
def to_device(data, device):
    """Recursively move tensors, possibly nested inside dicts, onto `device`.

    Dicts are rebuilt with the same keys; any value that is neither a tensor
    nor a dict is returned unchanged.
    """
    if torch.is_tensor(data):
        return data.to(device)
    if isinstance(data, dict):
        return {key: to_device(value, device) for key, value in data.items()}
    return data

number_of_batches = len(dataloader)

def train_one_epoch():
    """Run one pass over `dataloader`, updating `model` with the BPR objective.

    For each batch the user tower produces a user vector; the sampled positive
    and negative movies are embedded with the tower's own item table
    (model.item_emb), and -log(sigmoid(s_pos - s_neg)) is minimised, where
    s_* are dot products with the user vector.

    Returns:
        float: mean of the per-batch losses over the epoch.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    # Make sure dropout is active regardless of the caller's model state.
    model.train()
    running_loss = 0.0
    n_batches = 0

    for batch in dataloader:
        batch = to_device(batch, device)
        optimizer.zero_grad()

        user_vec = model(batch['input'])
        pos_vec = model.item_emb(batch['pos'])
        neg_vec = model.item_emb(batch['neg'])

        pos_score = (user_vec * pos_vec).sum(dim=-1)
        neg_score = (user_vec * neg_vec).sum(dim=-1)
        # BPR Loss
        loss = -F.logsigmoid(pos_score - neg_score).mean()

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError('dataloader yielded no batches; cannot compute an epoch loss')
    return running_loss / n_batches


In [None]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # run start time; not used below
EPOCHS = 50

loss_history = []  # mean BPR loss of each epoch, kept for later inspection/plotting
for epoch in tqdm(range(EPOCHS)):
    model.train(True)
    avg_loss = train_one_epoch()
    loss_history.append(avg_loss)
    # tqdm.write prints above the progress bar; a bare print() inside a tqdm
    # loop interleaves with the bar and garbles it.
    tqdm.write(f"Epoch {epoch + 1} loss: {avg_loss}")


In [None]:
# Persist only the learned parameters (a state_dict), not the pickled module;
# reload into a UserTower built with the same input_dim / n_items via load_state_dict.
user_tower_weights = model.state_dict()
torch.save(user_tower_weights, 'user_tower.model')