In [4]:
import torch.nn as nn
import torch
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

import pandas as pd

from datetime import datetime
from tqdm import tqdm

import random

In [5]:
# Uncomment to enable autograd anomaly detection while debugging NaN/inf gradients (slows training down).
# torch.autograd.set_detect_anomaly(True)

# Per-user feature table (it also carries the movies_seq / ratings_seq / ts_seq columns used later)
# and the per-user lists of positively / negatively rated movieIds.
user_features, ratings_groupped_ids = (
    pd.read_parquet(path)
    for path in (
        '../datasets/user_features_clean.parquet',
        '../datasets/ratings_groupped_ids.parquet',
    )
)

In [6]:
# DataFrame.info() prints its report itself and returns None, so it is called directly rather
# than wrapped in print() (which only adds a stray "None" line to the output).
user_features.info()
ratings_groupped_ids.info()

# collate() draws one positive and one negative movie per user with random.choice, so every user
# needs at least one entry in both lists.
empty_pos_ratings = int(ratings_groupped_ids['pos'].map(len).eq(0).sum())
empty_neg_ratings = int(ratings_groupped_ids['neg'].map(len).eq(0).sum())

if empty_pos_ratings != 0 or empty_neg_ratings != 0:
    print(f'Empty ratings: pos: {empty_pos_ratings}, neg: {empty_neg_ratings}')
    # ValueError (a subclass of Exception) signals bad input data more precisely than a bare Exception.
    raise ValueError("Users without a single pos/neg rating exist in the ratings_groupped_ids dataset")

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 198832 entries, 0 to 198831
Data columns (total 29 columns):
 #   Column                   Non-Null Count   Dtype  
---  ------                   --------------   -----  
 0   userId                   198832 non-null  int64  
 1   num_rating               198832 non-null  float64
 2   avg_rating               198832 non-null  float64
 3   weekend_watcher          198832 non-null  float64
 4   genre_Action             198832 non-null  float64
 5   genre_Adventure          198832 non-null  float64
 6   genre_Animation          198832 non-null  float64
 7   genre_Comedy             198832 non-null  float64
 8   genre_Crime              198832 non-null  float64
 9   genre_Documentary        198832 non-null  float64
 10  genre_Drama              198832 non-null  float64
 11  genre_Family             198832 non-null  float64
 12  genre_Fantasy            198832 non-null  float64
 13  genre_History            198832 non-null  float64
 14  genr

# Mapowanie movieId do ciągłego przedziału liczb naturalnych, aby umożliwić użycie nn.Embedding

In [7]:
# Collect every movieId that can reach nn.Embedding: the users' history sequences and the
# positive / negative targets. Series.explode() turns an empty list into NaN; dropna() discards
# it so it is not counted as an extra "movieId" (which would inflate n_items).
unique_ids = set().union(*(
    id_lists.explode().dropna().tolist()
    for id_lists in (
        user_features['movies_seq'],
        ratings_groupped_ids['pos'],
        ratings_groupped_ids['neg'],
    )
))

print('Unique movieIds:', len(unique_ids))
unique_ids = sorted(unique_ids)

# Contiguous 0 .. n_items - 1 row indices for nn.Embedding.
movieId_to_idx = {id_: idx for idx, id_ in enumerate(unique_ids)}
print('min idx:', min(movieId_to_idx.values()))
print('max idx:', max(movieId_to_idx.values()))

n_items = len(unique_ids)

assert min(movieId_to_idx.values()) == 0
assert max(movieId_to_idx.values()) == n_items - 1

Unique movieIds: 82932
min idx: 0
max idx: 82931


In [8]:
# Convert movieIds in user_features / ratings_groupped_ids to the row indices accepted by nn.Embedding
def map_list(col):
    '''
    Translate a sequence of raw movieIds into nn.Embedding row indices via movieId_to_idx.

    Raises KeyError if a movieId is missing from movieId_to_idx.
    '''
    return [movieId_to_idx[m] for m in col]


# Every column whose values are looked up in item_emb (history as well as BPR targets).
_ID_COLUMNS = [
    (user_features, 'movies_seq'),
    (ratings_groupped_ids, 'pos'),
    (ratings_groupped_ids, 'neg'),
]

# NOTE: the columns are overwritten in place, so this cell is not idempotent. Running it a second
# time treats the already-mapped indices as movieIds (KeyError or a silently wrong remap).
# Re-run from the data-loading cell instead.
for df, col in _ID_COLUMNS:
    df[col] = df[col].apply(map_list)


# movies_seq is checked as well: it is passed to item_emb directly, just like pos/neg.
max_idx = max(movieId_to_idx.values())
for df, col in _ID_COLUMNS:
    assert all(0 <= id_ <= max_idx for l in df[col] for id_ in l), f'Out of range index in column {col}'

In [9]:
class UserDataset(Dataset):
    '''Map-style dataset over a user-features DataFrame; item `i` is the i-th row by position (iloc).'''

    def __init__(self, user_features):
        # Rows are handed out untouched; conversion to tensors is left to the DataLoader's collate_fn.
        self.data = user_features

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        return row



In [10]:
EMB_DIM = 64

class UserTower(nn.Module):
    '''
    User tower: encodes a user as an L2-normalised `embedding_dim` vector built from a weighted mean
    of the embeddings of the movies in the user's history, concatenated with dense user features
    and passed through an MLP.

    `item_emb` is a regular attribute, so callers can also use it to look up candidate-item vectors.
    '''

    def __init__(self, input_dim, n_items, embedding_dim=EMB_DIM,
                 hidden_dims=(512, 384, 256), dropout=0.2):
        '''
        input_dim     - the number of columns in user features, without sequence columns
        n_items       - size of the item vocabulary; valid item ids are 0 .. n_items - 1
        embedding_dim - size of the item embeddings and of the returned user vector
        hidden_dims   - widths of the MLP's hidden layers
        dropout       - dropout probability applied after every hidden layer
        '''
        super().__init__()

        self.item_emb = nn.Embedding(n_items, embedding_dim)

        # Projects each history entry's (rating, timestamp) pair to a scalar pooling weight.
        self.rating_proj = nn.Linear(2, 1)

        # [Linear, ReLU, Dropout] per hidden width, then a final Linear. With the default widths this
        # yields the same Sequential indices (state_dict keys mlp.0 / mlp.3 / mlp.6 / mlp.9) as the
        # previous hand-written layout, so previously saved weights still load.
        layers = []
        in_features = input_dim + embedding_dim
        for width in hidden_dims:
            layers += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(dropout)]
            in_features = width
        layers.append(nn.Linear(in_features, embedding_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, batch):
        '''
        batch - dict with
            'movies'        long tensor (B, L): item indices of the user's history
            'ratings'       float tensor (B, L)
            'timestamps'    float tensor (B, L)
            'user_features' float tensor (B, input_dim)
            'mask'          optional (B, L) tensor, 1 for real history entries and 0 for padding;
                            when absent every position is pooled
        Returns an L2-normalised float tensor (B, embedding_dim).
        '''
        # Embed the movies in the user's history: (B, L, D)
        m = self.item_emb(batch['movies'])

        # Per-entry pooling weights in (0, 1): (B, L, 1)
        # NOTE(review): timestamps enter the linear layer as-is; if they are raw epoch values
        # (not visible here) the sigmoid will saturate — confirm they are normalised upstream.
        x = torch.stack([batch['ratings'], batch['timestamps']], dim=-1)
        w = torch.sigmoid(self.rating_proj(x))

        mask = batch.get('mask')
        if mask is not None:
            # Zero weight on padded positions so they contribute to neither sum below.
            w = w * mask.unsqueeze(-1).to(w.dtype)

        # Weighted mean-pool over the history; clamp guards against an all-zero weight row.
        pooled = (m * w).sum(1) / w.sum(1).clamp_min(1e-6)

        features = torch.cat([batch['user_features'], pooled], dim=-1)
        return F.normalize(self.mlp(features), dim=1)


In [11]:
n_items = len(unique_ids)

# Dense per-user feature columns, in the order UserTower's MLP receives them.
# NOTE(review): UserTower is instantiated with input_dim=25 in a later cell; that value must equal
# len(USER_FEATURE_COLS).
USER_FEATURE_COLS = [
    'num_rating', 'avg_rating', 'weekend_watcher',
    'genre_Action', 'genre_Adventure', 'genre_Animation', 'genre_Comedy', 'genre_Crime',
    'genre_Documentary', 'genre_Drama', 'genre_Family', 'genre_Fantasy', 'genre_History',
    'genre_Horror', 'genre_Music', 'genre_Mystery', 'genre_Romance', 'genre_Science Fiction',
    'genre_TV Movie', 'genre_Thriller', 'genre_War', 'genre_Western',
    'type_of_viewer_negative', 'type_of_viewer_neutral', 'type_of_viewer_positive',
]

# userId -> (pos item indices, neg item indices). Built once so collate() does an O(1) dict lookup
# per user instead of a boolean scan over the whole ratings_groupped_ids frame for every row of
# every batch. keep='first' matches the previous `.iloc[0]` choice should a userId repeat.
# Re-run this cell if ratings_groupped_ids changes.
_ratings_first = ratings_groupped_ids.drop_duplicates(subset='userId', keep='first')
POS_NEG_BY_USER = dict(zip(
    _ratings_first['userId'],
    zip(_ratings_first['pos'], _ratings_first['neg']),
))
del _ratings_first


def collate(batch):
    '''
    Turn a list of user_features rows (from UserDataset) into model-ready tensors.

    For each user, one positive and one negative item are drawn with random.choice from that
    user's pos / neg lists; they are the targets for the BPR loss and for evaluation.

    Returns
        {"input": {"user_features": (B, len(USER_FEATURE_COLS)) float32,
                   "movies": (B, L) long, "ratings": (B, L) float32, "timestamps": (B, L) float32},
         "pos": (B,) long, "neg": (B,) long}
    torch.stack requires every row's movies_seq / ratings_seq / ts_seq to have the same length L.

    Raises KeyError if a row's userId has no entry in ratings_groupped_ids.
    '''
    feature_rows, movies, ratings, timestamps, pos, neg = [], [], [], [], [], []

    for row in batch:
        movies.append(torch.tensor(row['movies_seq'], dtype=torch.long))
        ratings.append(torch.tensor(row['ratings_seq'], dtype=torch.float32))
        timestamps.append(torch.tensor(row['ts_seq'], dtype=torch.float32))

        feats = row[USER_FEATURE_COLS].astype('float32').values
        feature_rows.append(torch.tensor(feats, dtype=torch.float32))

        user_id = row['userId']
        try:
            pos_ids, neg_ids = POS_NEG_BY_USER[user_id]
        except KeyError:
            raise KeyError(f'userId {user_id} not found in ratings_groupped_ids') from None
        # pos is drawn before neg (as before), so the `random` stream is consumed in the same order.
        pos.append(random.choice(pos_ids))
        neg.append(random.choice(neg_ids))

    return {
        "input": {
            "user_features": torch.stack(feature_rows),
            "movies": torch.stack(movies),
            "ratings": torch.stack(ratings),
            "timestamps": torch.stack(timestamps),
        },
        "pos": torch.tensor(pos, dtype=torch.long),
        "neg": torch.tensor(neg, dtype=torch.long),
    }

In [12]:
from sklearn.model_selection import train_test_split

BATCH_SIZE = 4096
# Fixed seed so the train/test user split — and therefore every reported metric — is reproducible.
SPLIT_SEED = 42

train, test = train_test_split(user_features, test_size=0.2, random_state=SPLIT_SEED)

trainDataset = UserDataset(train)
trainDataLoader = DataLoader(trainDataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)

# Evaluation does not need shuffling; a fixed order keeps batch composition stable between evaluations.
# NOTE(review): collate() still samples the pos/neg pair at random, so test metrics carry sampling noise.
testDataset = UserDataset(test)
testDataLoader = DataLoader(testDataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

# Prefer CUDA, then Apple MPS, then CPU. torch.backends.mps.is_available() is the documented MPS check
# (torch.mps.is_available may not exist in every PyTorch release).
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print('Device:', device)

Device: cuda


In [13]:
# Commented out because it takes ~3 minutes to run (and the check passes)

# # Verify the data 
# def check_ids(tensor, column):
#     if (tensor < 0).any() or (tensor >= n_items).any():
#         raise ValueError(f"Out of range index in column {column}. Value: {tensor[(tensor<0) | (tensor >= n_items)]}")


# for row in tqdm(trainDataLoader):
#     check_ids(row['input']['movies'], 'input.movies')
#     check_ids(row['pos'], 'pos')
#     check_ids(row['neg'], 'neg')

In [14]:
import torch.optim as optim

# Build the user tower on the selected device plus an Adam optimiser over all of its parameters
# (including item_emb, which the training loop also uses to embed the pos/neg target items).
# NOTE(review): input_dim=25 is hard-coded; it must equal the number of user-feature columns
# selected in collate().
model = UserTower(n_items=len(unique_ids), input_dim=25)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# NOTE(review): loss_fn is not used by train_one_epoch, which computes a BPR loss inline;
# this MSE looks like a leftover — confirm before relying on it.
loss_fn = F.mse_loss

In [15]:
def to_device(data, device):
    '''
    Recursively move every tensor inside `data` to `device`.

    Dicts keep their keys; lists and tuples (including namedtuples) are rebuilt as the same kind
    of container; tensors are moved with `.to(device)`; any other value is returned unchanged.
    '''
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    if isinstance(data, list):
        return [to_device(v, device) for v in data]
    if isinstance(data, tuple):
        items = [to_device(v, device) for v in data]
        # namedtuples take their fields positionally; plain tuples take a single iterable.
        return type(data)(*items) if hasattr(data, '_fields') else tuple(items)
    if torch.is_tensor(data):
        return data.to(device)
    return data

number_of_batches = len(trainDataLoader)

def train_one_epoch(dataloader=None, log_every=None):
    '''
    Run one optimisation pass with the BPR loss and return the mean per-batch loss.

    dataloader - iterable of collate() batches; defaults to the global trainDataLoader
    log_every  - if set, print the running mean loss every `log_every` batches
    Returns the mean of the per-batch losses, or nan when the loader yields no batches.
    Uses the global model, optimizer and device.
    '''
    if dataloader is None:
        dataloader = trainDataLoader

    # Dropout in the UserTower MLP must be active while training.
    model.train()

    running_loss = 0.0
    total = 0

    for i, data in enumerate(dataloader):
        data = to_device(data, device)
        optimizer.zero_grad()

        u = model(data['input'])
        # Target items are embedded with the same table the user tower pools the history from.
        pos_vec = model.item_emb(data['pos'])
        neg_vec = model.item_emb(data['neg'])

        pos_score = (u * pos_vec).sum(dim=-1)
        neg_score = (u * neg_vec).sum(dim=-1)
        # BPR loss: -log sigmoid(s_pos - s_neg), averaged over the batch.
        loss = -F.logsigmoid(pos_score - neg_score).mean()

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total += 1

        if log_every and (i + 1) % log_every == 0:
            try:
                n_batches = len(dataloader)
            except TypeError:
                n_batches = '?'
            print(f'Loss for batch {i + 1}/{n_batches}: {running_loss / total:.4f}')

    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return float('nan')
    return running_loss / total


In [16]:
from sklearn.metrics import roc_auc_score
import numpy as np

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
EPOCHS = 50
EVAL_EVERY = 5  # evaluate on the test users after every EVAL_EVERY-th epoch


def evaluate(dataloader):
    '''
    Score one sampled (positive, negative) item pair per user in `dataloader`.

    Returns (roc_auc, pairwise_accuracy), both computed over the whole loader:
      roc_auc           - ROC AUC of positive (label 1) vs negative (label 0) scores pooled over all users
      pairwise_accuracy - fraction of users whose positive item scores strictly higher than the negative
    Scores are accumulated across batches instead of averaging per-batch metrics, so a smaller last
    batch does not get the same weight as a full one. Returns (nan, nan) for an empty loader.
    '''
    model.eval()
    pos_chunks, neg_chunks = [], []
    with torch.no_grad():
        for batch in dataloader:
            batch = to_device(batch, device)

            u = model(batch['input'])
            pos_chunks.append((u * model.item_emb(batch['pos'])).sum(dim=-1).cpu())
            neg_chunks.append((u * model.item_emb(batch['neg'])).sum(dim=-1).cpu())

    if not pos_chunks:
        return float('nan'), float('nan')

    pos_score = torch.cat(pos_chunks)
    neg_score = torch.cat(neg_chunks)

    labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
    scores = torch.cat([pos_score, neg_score])
    auc = float(roc_auc_score(labels.numpy(), scores.numpy()))

    pair_acc = (pos_score > neg_score).float().mean().item()
    return auc, pair_acc


for epoch in tqdm(range(EPOCHS)):
    model.train(True)
    avg_loss = train_one_epoch()

    # tqdm.write keeps the progress bar intact instead of interleaving it with plain prints.
    tqdm.write(f"Epoch {epoch + 1} loss: {avg_loss}")

    if (epoch + 1) % EVAL_EVERY == 0:
        auc, pair_acc = evaluate(testDataLoader)
        tqdm.write(f'Epoch {epoch + 1}. ROC AUC: {auc}, Pair-wise accuracy: {pair_acc}')


            

  2%|▏         | 1/50 [02:34<2:06:34, 155.00s/it]

Epoch 1 loss: 0.8126042103156065


  4%|▍         | 2/50 [05:14<2:05:58, 157.48s/it]

Epoch 2 loss: 0.7571149346155998


  6%|▌         | 3/50 [07:48<2:02:08, 155.92s/it]

Epoch 3 loss: 0.7189535315220172


  8%|▊         | 4/50 [10:25<1:59:49, 156.29s/it]

Epoch 4 loss: 0.6860592288848681
Epoch 5 loss: 0.6617677853657649


 10%|█         | 5/50 [13:35<2:06:33, 168.76s/it]

Epoch 5. ROC AUC: 0.6325082905841728, Pair-wise accuaracy: 0.6361523330211639


 12%|█▏        | 6/50 [16:08<1:59:45, 163.31s/it]

Epoch 6 loss: 0.6435369192025601


 14%|█▍        | 7/50 [18:41<1:54:29, 159.75s/it]

Epoch 7 loss: 0.6287687436128274


 16%|█▌        | 8/50 [21:13<1:50:10, 157.40s/it]

Epoch 8 loss: 0.6153981517522763


 18%|█▊        | 9/50 [23:46<1:46:32, 155.92s/it]

Epoch 9 loss: 0.60418640038906
Epoch 10 loss: 0.5956538823934702


 20%|██        | 10/50 [26:56<1:51:06, 166.65s/it]

Epoch 10. ROC AUC: 0.6756590251277825, Pair-wise accuaracy: 0.6830587565898896


 22%|██▏       | 11/50 [29:29<1:45:28, 162.26s/it]

Epoch 11 loss: 0.5893668028024527


 24%|██▍       | 12/50 [32:02<1:40:58, 159.44s/it]

Epoch 12 loss: 0.5819679086024945


 26%|██▌       | 13/50 [34:34<1:36:57, 157.23s/it]

Epoch 13 loss: 0.5776272859328833


 28%|██▊       | 14/50 [37:06<1:33:30, 155.85s/it]

Epoch 14 loss: 0.571480644054902
Epoch 15 loss: 0.5662625309748527


 30%|███       | 15/50 [40:17<1:37:01, 166.32s/it]

Epoch 15. ROC AUC: 0.6989596011800925, Pair-wise accuaracy: 0.7074283301830292


 32%|███▏      | 16/50 [42:50<1:31:55, 162.23s/it]

Epoch 16 loss: 0.5644027346219772


 34%|███▍      | 17/50 [45:23<1:27:41, 159.43s/it]

Epoch 17 loss: 0.5578987506719736


 36%|███▌      | 18/50 [47:55<1:23:53, 157.30s/it]

Epoch 18 loss: 0.5532011894079355


 38%|███▊      | 19/50 [50:28<1:20:35, 156.00s/it]

Epoch 19 loss: 0.5511576166519752
Epoch 20 loss: 0.5455219271855477


 40%|████      | 20/50 [53:38<1:23:08, 166.27s/it]

Epoch 20. ROC AUC: 0.7146952428596289, Pair-wise accuaracy: 0.7257957279682159


 42%|████▏     | 21/50 [56:11<1:18:28, 162.35s/it]

Epoch 21 loss: 0.5436377861560919


 44%|████▍     | 22/50 [58:44<1:14:23, 159.41s/it]

Epoch 22 loss: 0.5403453585429069


 46%|████▌     | 23/50 [1:01:16<1:10:45, 157.25s/it]

Epoch 23 loss: 0.5395209621160458


 48%|████▊     | 24/50 [1:03:49<1:07:35, 155.96s/it]

Epoch 24 loss: 0.5352153854492383
Epoch 25 loss: 0.5319998233746259


 50%|█████     | 25/50 [1:06:58<1:09:05, 165.83s/it]

Epoch 25. ROC AUC: 0.7200277497679555, Pair-wise accuaracy: 0.7331005454063415


 52%|█████▏    | 26/50 [1:09:30<1:04:39, 161.63s/it]

Epoch 26 loss: 0.5292629248056656


 54%|█████▍    | 27/50 [1:12:01<1:00:47, 158.60s/it]

Epoch 27 loss: 0.5256975063910851


 56%|█████▌    | 28/50 [1:14:32<57:18, 156.31s/it]  

Epoch 28 loss: 0.5249016284942627


 58%|█████▊    | 29/50 [1:17:04<54:11, 154.83s/it]

Epoch 29 loss: 0.5227055075841073
Epoch 30 loss: 0.520447952625079


 60%|██████    | 30/50 [1:20:13<55:01, 165.07s/it]

Epoch 30. ROC AUC: 0.7325619561501617, Pair-wise accuaracy: 0.7418752253055573


 62%|██████▏   | 31/50 [1:22:44<50:59, 161.02s/it]

Epoch 31 loss: 0.519850506232335


 64%|██████▍   | 32/50 [1:25:17<47:31, 158.41s/it]

Epoch 32 loss: 0.5173639563413767


 66%|██████▌   | 33/50 [1:27:48<44:16, 156.24s/it]

Epoch 33 loss: 0.5173166455366672


 68%|██████▊   | 34/50 [1:30:19<41:18, 154.88s/it]

Epoch 34 loss: 0.5137828290462494
Epoch 35 loss: 0.5101252297560374


 70%|███████   | 35/50 [1:33:28<41:15, 165.03s/it]

Epoch 35. ROC AUC: 0.7364991585483966, Pair-wise accuaracy: 0.7497582852840423


 72%|███████▏  | 36/50 [1:36:00<37:33, 160.98s/it]

Epoch 36 loss: 0.5087636671005151


 74%|███████▍  | 37/50 [1:38:31<34:15, 158.13s/it]

Epoch 37 loss: 0.5087475478649139


 76%|███████▌  | 38/50 [1:41:02<31:13, 156.09s/it]

Epoch 38 loss: 0.5072259627855741


 78%|███████▊  | 39/50 [1:43:34<28:22, 154.78s/it]

Epoch 39 loss: 0.5058857049697485
Epoch 40 loss: 0.502375453710556


 80%|████████  | 40/50 [1:46:44<27:31, 165.14s/it]

Epoch 40. ROC AUC: 0.7429751451104261, Pair-wise accuaracy: 0.7549966037273407


 82%|████████▏ | 41/50 [1:49:15<24:09, 161.05s/it]

Epoch 41 loss: 0.5016683500546676


 84%|████████▍ | 42/50 [1:51:46<21:05, 158.14s/it]

Epoch 42 loss: 0.5003491731790396


 86%|████████▌ | 43/50 [1:54:17<18:11, 155.98s/it]

Epoch 43 loss: 0.4990273072169377


 88%|████████▊ | 44/50 [1:56:49<15:28, 154.72s/it]

Epoch 44 loss: 0.49731707802185643
Epoch 45 loss: 0.49523939459751815


 90%|█████████ | 45/50 [1:59:58<13:44, 165.00s/it]

Epoch 45. ROC AUC: 0.7459261998994836, Pair-wise accuaracy: 0.7552544593811035


 92%|█████████▏| 46/50 [2:02:30<10:44, 161.04s/it]

Epoch 46 loss: 0.4959277503001384


 94%|█████████▍| 47/50 [2:05:02<07:54, 158.30s/it]

Epoch 47 loss: 0.4923977760168222


 96%|█████████▌| 48/50 [2:07:33<05:12, 156.30s/it]

Epoch 48 loss: 0.49336252304223865


 98%|█████████▊| 49/50 [2:10:05<02:34, 154.89s/it]

Epoch 49 loss: 0.4915953102784279
Epoch 50 loss: 0.4903713579361255


100%|██████████| 50/50 [2:13:14<00:00, 159.89s/it]

Epoch 50. ROC AUC: 0.7485580821716291, Pair-wise accuaracy: 0.7581757843494416





In [17]:
# Persist only the learned weights (state_dict), tagged with the training-run timestamp.
torch.save(
    obj=model.state_dict(),
    f=f'user_tower_{timestamp}.model',
)