In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

import pandas as pd

from datetime import datetime
from tqdm import tqdm

import random

In [2]:
# torch.autograd.set_detect_anomaly(True)

user_features = pd.read_parquet('../datasets/user_features_clean.parquet')
ratings_groupped_ids = pd.read_parquet('../datasets/ratings_groupped_ids.parquet')

In [3]:
print(user_features.info())
print(ratings_groupped_ids.info())

empty_pos_ratings = ratings_groupped_ids['pos'].apply(lambda x: len(x) == 0).sum()
empty_neg_ratings = ratings_groupped_ids['neg'].apply(lambda x: len(x) == 0).sum()

if empty_pos_ratings != 0 or empty_neg_ratings != 0:
    print(f'Empty ratings: pos: {empty_pos_ratings}, neg: {empty_neg_ratings}')
    raise Exception("Users without a single pos/neg rating exist in the ratings_groupped_ids dataset")

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 198832 entries, 0 to 198831
Data columns (total 29 columns):
 #   Column                   Non-Null Count   Dtype  
---  ------                   --------------   -----  
 0   userId                   198832 non-null  int64  
 1   num_rating               198832 non-null  float64
 2   avg_rating               198832 non-null  float64
 3   weekend_watcher          198832 non-null  float64
 4   genre_Action             198832 non-null  float64
 5   genre_Adventure          198832 non-null  float64
 6   genre_Animation          198832 non-null  float64
 7   genre_Comedy             198832 non-null  float64
 8   genre_Crime              198832 non-null  float64
 9   genre_Documentary        198832 non-null  float64
 10  genre_Drama              198832 non-null  float64
 11  genre_Family             198832 non-null  float64
 12  genre_Fantasy            198832 non-null  float64
 13  genre_History            198832 non-null  float64
 14  genr

# Mapowanie movieId do ciągłego przedziału liczb naturalnych, aby umożliwić użycie nn.Embedding

In [4]:
unique_ids = set(
        user_features['movies_seq'].explode().tolist()
        + ratings_groupped_ids['pos'].explode().tolist() 
        + ratings_groupped_ids['neg'].explode().tolist()
    )

print('Unique movieIds:', len(unique_ids))
unique_ids = sorted(unique_ids)

movieId_to_idx = {id_: idx for idx, id_ in enumerate(unique_ids)}
print('min idx:', min(movieId_to_idx.values()))
print('max idx:', max(movieId_to_idx.values()))

n_items = len(unique_ids)

assert min(movieId_to_idx.values()) == 0
assert max(movieId_to_idx.values()) == n_items - 1

Unique movieIds: 82932
min idx: 0
max idx: 82931


In [5]:
# Convert movieIds in ratings_groupped_ids to the ones accepted by nn.Embedding
def map_list(col):
    return [movieId_to_idx[m] for m in col]

for df, col in [
    (user_features, 'movies_seq'),
    (ratings_groupped_ids, 'pos'),
    (ratings_groupped_ids, 'neg')]:
    df[col] = df[col].apply(map_list)


max_idx = max(movieId_to_idx.values())
assert all(0 <= id_ <= max_idx for l in ratings_groupped_ids['pos'] for id_ in l)
assert all(0 <= id_ <= max_idx for l in ratings_groupped_ids['neg'] for id_ in l)

In [6]:
class UserDataset(Dataset):
    def __init__(self, user_features):
        self.data = user_features
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data.iloc[idx]



In [7]:
EMB_DIM = 64

class UserTower(nn.Module):
    def __init__(self, input_dim, n_items, embedding_dim=EMB_DIM):
        '''
        input_dim - the number of columns in user features, without sequence columns
        '''
        super().__init__()

        self.item_emb = nn.Embedding(n_items, embedding_dim)

        # A layer to project rating and timestamp into a scalar weight
        self.rating_proj = nn.Linear(2, 1)

        self.mlp = nn.Sequential(
            nn.Linear(input_dim + embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 384),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(384, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, embedding_dim)
        )
    
    def forward(self, batch):
        # Embed movieIds liked by user
        m = self.item_emb(batch['movies'])

        # Get weights 
        x = torch.stack([batch['ratings'], batch['timestamps']], dim=-1)
        w = torch.sigmoid(self.rating_proj(x))

        # weighted mean-pool
        pooled = (m * w).sum(1) / (w.sum(1).clamp_min(1e-6))

        input = torch.cat([batch['user_features'], pooled], dim=-1)
        output = self.mlp(input)
        u = F.normalize(output, dim = 1)
        return u


In [8]:
n_items = len(unique_ids)

def collate(batch):
    user_features, movies, ratings, timestamps, pos, neg = [], [], [], [], [], []

    for row in batch:
        movies.append(torch.tensor(row['movies_seq'], dtype=torch.long))
        ratings.append(torch.tensor(row['ratings_seq'], dtype=torch.float32))
        timestamps.append(torch.tensor(row['ts_seq'], dtype=torch.float32))

        userId = row['userId']

        r = row[['num_rating', 'avg_rating', 'weekend_watcher', 'genre_Action', 'genre_Adventure', 'genre_Animation', 'genre_Comedy', 'genre_Crime', 'genre_Documentary', 'genre_Drama', 'genre_Family', 'genre_Fantasy', 'genre_History', 'genre_Horror', 'genre_Music', 'genre_Mystery', 'genre_Romance', 'genre_Science Fiction', 'genre_TV Movie', 'genre_Thriller', 'genre_War', 'genre_Western', 'type_of_viewer_negative', 'type_of_viewer_neutral', 'type_of_viewer_positive']]
        r = r.astype('float32').values
        user_features.append(torch.tensor(r, dtype=torch.float32))
        
        # Get a random movieId that was rated positively and one that was rated negatively. 
        # Used during training to calculate BPR loss. 
        posAndNegRow = ratings_groupped_ids[ratings_groupped_ids['userId'] == userId].iloc[0]
        pos.append(torch.tensor(random.choice(posAndNegRow['pos']), dtype=torch.long))
        neg.append(torch.tensor(random.choice(posAndNegRow['neg']), dtype=torch.long))

    return {
        "input": {
            "user_features": torch.stack(user_features),
            "movies": torch.stack(movies),
            "ratings": torch.stack(ratings),
            "timestamps": torch.stack(timestamps),
        },
        "pos": torch.as_tensor(pos, dtype=torch.long),
        "neg": torch.as_tensor(neg, dtype=torch.long)
    }

In [9]:
BATCH_SIZE = 4096

from sklearn.model_selection import train_test_split

train, test = train_test_split(user_features, test_size=0.2)

trainDataset = UserDataset(train)
trainDataLoader = DataLoader(trainDataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)

testDataset = UserDataset(test)
testDataLoader = DataLoader(testDataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)

device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.mps.is_available():
    device = torch.device('mps')
print('Device:', device)

Device: mps


In [10]:
# Commented out because lasts ~3 minutes and passes

# # Verify the data 
# def check_ids(tensor, column):
#     if (tensor < 0).any() or (tensor >= n_items).any():
#         raise ValueError(f"Out of range index in column {column}. Value: {tensor[(tensor<0) | (tensor >= n_items)]}")


# for row in tqdm(dataloader):
#     check_ids(row['input']['movies'], 'input.movies')
#     check_ids(row['pos'], 'pos')
#     check_ids(row['neg'], 'neg')

In [11]:
import torch.optim as optim

model = UserTower(input_dim=25, n_items=len(unique_ids)).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = F.mse_loss

In [12]:
def to_device(data, device):
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    elif torch.is_tensor(data):
        return data.to(device)
    else:
        return data

number_of_batches = len(trainDataLoader)

def train_one_epoch():
    running_loss = 0.0
    total = 0

    for i, data in enumerate(trainDataLoader):
        data = to_device(data, device)
        optimizer.zero_grad()

        u = model(data['input'])
        pos_vec = model.item_emb(data['pos'])
        neg_vec = model.item_emb(data['neg'])

        pos_score = (u * pos_vec).sum(dim=-1)
        neg_score = (u * neg_vec).sum(dim=-1)
        # BPR Loss
        loss = -F.logsigmoid(pos_score - neg_score).mean()

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        total += 1

        # if i % int(number_of_batches * 0.1) == 0:
        #     print(f'Loss for batch {i}/{number_of_batches}: {running_loss / total:.4f}')
    
    epoch_loss = running_loss / total
    return epoch_loss


In [None]:
from sklearn.metrics import roc_auc_score
import numpy as np

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
EPOCHS = 50

for epoch in tqdm(range(EPOCHS)):
    model.train(True)
    avg_loss = train_one_epoch()

    print(f"Epoch {epoch + 1} loss: {avg_loss}")

    # Evaluation

    if epoch % 5 == 4:
        aucs = []
        pair_acc = []

        model.eval()
        with torch.no_grad():
            for batch in testDataLoader:
                batch = to_device(batch, device)

                u = model(batch['input'])
                pos_vec = model.item_emb(batch['pos'])
                neg_vec = model.item_emb(batch['neg'])

                pos_score = (u * pos_vec).sum(dim = -1)
                neg_score = (u * neg_vec).sum(dim = -1)

                # ROC AUC
                labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
                scores = torch.cat([pos_score, neg_score])
                aucs.append(roc_auc_score(labels.cpu(), scores.cpu()))

                # Pair-wise accuarcy
                acc = (pos_score > neg_score).float().mean().item()

        print(f'Epoch {epoch + 1}. ROC AUC: {float(np.mean(aucs))}, Pair-wise accuaracy: {float(np.mean(pair_acc))}')


            

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1 loss: 0.8018107276696426


  2%|▏         | 1/50 [01:14<1:01:14, 74.98s/it]

Epoch 1 ROC AUC: 0.5577064058859318


  4%|▍         | 2/50 [02:12<51:48, 64.75s/it]  

Epoch 2 loss: 0.7500845117446704


  6%|▌         | 3/50 [03:10<48:12, 61.54s/it]

Epoch 3 loss: 0.7103847861289978


  8%|▊         | 4/50 [04:06<45:33, 59.43s/it]

Epoch 4 loss: 0.682935757514758


 10%|█         | 5/50 [05:02<43:36, 58.15s/it]

Epoch 5 loss: 0.6602280353888487


 12%|█▏        | 6/50 [05:58<42:04, 57.38s/it]

Epoch 6 loss: 0.6414574201290424


 14%|█▍        | 7/50 [06:54<40:48, 56.95s/it]

Epoch 7 loss: 0.6279569910122798


 16%|█▌        | 8/50 [07:55<40:48, 58.31s/it]

Epoch 8 loss: 0.6158155890611502


 18%|█▊        | 9/50 [08:49<38:58, 57.03s/it]

Epoch 9 loss: 0.6074109918031937


 20%|██        | 10/50 [09:44<37:27, 56.19s/it]

Epoch 10 loss: 0.5977927782596686
Epoch 11 loss: 0.5914914852533585


 22%|██▏       | 11/50 [10:51<38:48, 59.71s/it]

Epoch 11 ROC AUC: 0.6768392874204945


 24%|██▍       | 12/50 [11:46<36:47, 58.10s/it]

Epoch 12 loss: 0.5856108298668494


 26%|██▌       | 13/50 [12:40<35:04, 56.87s/it]

Epoch 13 loss: 0.5821403891612322


 28%|██▊       | 14/50 [13:34<33:38, 56.07s/it]

Epoch 14 loss: 0.5777470500041277


 30%|███       | 15/50 [14:28<32:21, 55.47s/it]

Epoch 15 loss: 0.5731931695571313


 32%|███▏      | 16/50 [15:22<31:13, 55.12s/it]

Epoch 16 loss: 0.5688190597754258


 34%|███▍      | 17/50 [16:16<30:09, 54.83s/it]

Epoch 17 loss: 0.5643275899764819


 36%|███▌      | 18/50 [17:11<29:08, 54.64s/it]

Epoch 18 loss: 0.560099551310906


 38%|███▊      | 19/50 [18:05<28:09, 54.48s/it]

Epoch 19 loss: 0.5602431129186581


 40%|████      | 20/50 [18:59<27:10, 54.36s/it]

Epoch 20 loss: 0.5564005543024112
Epoch 21 loss: 0.551607825817206


 42%|████▏     | 21/50 [20:07<28:13, 58.40s/it]

Epoch 21 ROC AUC: 0.7006610319519238


 44%|████▍     | 22/50 [21:01<26:39, 57.11s/it]

Epoch 22 loss: 0.5491172793583993


 46%|████▌     | 23/50 [21:55<25:16, 56.16s/it]

Epoch 23 loss: 0.5472998313414745


 48%|████▊     | 24/50 [22:49<24:04, 55.54s/it]

Epoch 24 loss: 0.5453726099087641


 50%|█████     | 25/50 [23:43<22:57, 55.10s/it]

Epoch 25 loss: 0.5400536106182978


 52%|█████▏    | 26/50 [24:37<21:54, 54.77s/it]

Epoch 26 loss: 0.5386487535941296


 54%|█████▍    | 27/50 [25:31<20:56, 54.64s/it]

Epoch 27 loss: 0.5366421387745783


 56%|█████▌    | 28/50 [26:25<19:58, 54.48s/it]

Epoch 28 loss: 0.5323177851163424


 58%|█████▊    | 29/50 [27:19<19:01, 54.35s/it]

Epoch 29 loss: 0.5310478791212424


 60%|██████    | 30/50 [28:14<18:06, 54.31s/it]

Epoch 30 loss: 0.5269046945449634
Epoch 31 loss: 0.5263647865026425


 62%|██████▏   | 31/50 [29:21<18:28, 58.32s/it]

Epoch 31 ROC AUC: 0.7241083374263156


 64%|██████▍   | 32/50 [30:16<17:07, 57.10s/it]

Epoch 32 loss: 0.5245819382178478


 66%|██████▌   | 33/50 [31:10<15:55, 56.20s/it]

Epoch 33 loss: 0.5215428884212787


 68%|██████▊   | 34/50 [32:04<14:49, 55.59s/it]

Epoch 34 loss: 0.5213764890646323


 70%|███████   | 35/50 [32:58<13:48, 55.24s/it]

Epoch 35 loss: 0.5189568385099753


 72%|███████▏  | 36/50 [33:52<12:48, 54.92s/it]

Epoch 36 loss: 0.5165863396265568


 74%|███████▍  | 37/50 [34:47<11:51, 54.71s/it]

Epoch 37 loss: 0.5143485604188381


 76%|███████▌  | 38/50 [35:41<10:54, 54.53s/it]

Epoch 38 loss: 0.5117840392467303


 78%|███████▊  | 39/50 [36:35<09:58, 54.41s/it]

Epoch 39 loss: 0.5091085709058322


 80%|████████  | 40/50 [37:29<09:03, 54.38s/it]

Epoch 40 loss: 0.5085443013753647
Epoch 41 loss: 0.5073813062447768


 82%|████████▏ | 41/50 [38:37<08:44, 58.29s/it]

Epoch 41 ROC AUC: 0.7390895309390586


 84%|████████▍ | 42/50 [39:31<07:36, 57.02s/it]

Epoch 42 loss: 0.5051175699784205


 86%|████████▌ | 43/50 [40:25<06:32, 56.14s/it]

Epoch 43 loss: 0.5052691934964596


 88%|████████▊ | 44/50 [41:19<05:33, 55.61s/it]

Epoch 44 loss: 0.5018550631327506


 90%|█████████ | 45/50 [42:13<04:35, 55.16s/it]

Epoch 45 loss: 0.49971548104897523


 92%|█████████▏| 46/50 [43:07<03:39, 54.86s/it]

Epoch 46 loss: 0.4990856892023331


 94%|█████████▍| 47/50 [44:02<02:43, 54.67s/it]

Epoch 47 loss: 0.4995819681730026


 96%|█████████▌| 48/50 [44:56<01:49, 54.57s/it]

Epoch 48 loss: 0.4930452673863142


 98%|█████████▊| 49/50 [45:50<00:54, 54.46s/it]

Epoch 49 loss: 0.492585749962391


100%|██████████| 50/50 [46:44<00:00, 56.09s/it]

Epoch 50 loss: 0.4944544931252797





In [14]:
torch.save(model.state_dict(), f'user_tower_{timestamp}.model')