In [367]:
import torch.nn as nn
import torch
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

import pandas as pd
import numpy as np

In [368]:
# Read both input tables in one pass: the per-user feature table first,
# then the per-user positive / negative rating id lists.
user_features, ratings_groupped_ids = (
    pd.read_parquet(parquet_path)
    for parquet_path in (
        '../datasets/user_features_clean.parquet',
        '../datasets/ratings_groupped_ids.parquet',
    )
)

In [None]:
# DataFrame.info() writes its report to stdout itself and returns None, so it
# is called directly; wrapping it in print() would emit an extra "None" line.
user_features.info()
ratings_groupped_ids.info()

# Count the users whose positive / negative rating lists are empty.
empty_pos_ratings = ratings_groupped_ids['pos'].map(len).eq(0).sum()
empty_neg_ratings = ratings_groupped_ids['neg'].map(len).eq(0).sum()

# collate() draws one random id from each list per user, which cannot work on
# an empty list, so stop here instead of failing in the middle of training.
if empty_pos_ratings != 0 or empty_neg_ratings != 0:
    print(f'Empty ratings: pos: {empty_pos_ratings}, neg: {empty_neg_ratings}')
    raise ValueError("Users without a single pos/neg rating exist in the ratings_groupped_ids dataset")

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 198832 entries, 0 to 198831
Data columns (total 29 columns):
 #   Column                   Non-Null Count   Dtype  
---  ------                   --------------   -----  
 0   userId                   198832 non-null  int64  
 1   num_rating               198832 non-null  float64
 2   avg_rating               198832 non-null  float64
 3   weekend_watcher          198832 non-null  float64
 4   genre_Action             198832 non-null  float64
 5   genre_Adventure          198832 non-null  float64
 6   genre_Animation          198832 non-null  float64
 7   genre_Comedy             198832 non-null  float64
 8   genre_Crime              198832 non-null  float64
 9   genre_Documentary        198832 non-null  float64
 10  genre_Drama              198832 non-null  float64
 11  genre_Family             198832 non-null  float64
 12  genre_Fantasy            198832 non-null  float64
 13  genre_History            198832 non-null  float64
 14  genr

In [370]:
class UserDataset(Dataset):
    """Map-style dataset that serves one row of a user table per index."""

    def __init__(self, user_features):
        """Keep a reference to ``user_features``; the table is not copied."""
        self.data = user_features

    def __len__(self):
        """Return the number of rows in the wrapped table."""
        n_rows = len(self.data)
        return n_rows

    def __getitem__(self, idx):
        """Return the row at positional index ``idx`` (looked up via ``.iloc``)."""
        row = self.data.iloc[idx]
        return row



In [None]:
EMB_DIM = 64  # default width of the item embeddings and of the user vector
PAD_ID = 0    # item id treated as padding inside the history sequences


class UserTower(nn.Module):
    """Encode a user's dense features and rated-item history into a unit vector.

    The rated items are embedded, pooled with learned per-item weights, then
    concatenated with the dense user features and passed through an MLP.
    """

    def __init__(self, input_dim, n_items, embedding_dim=EMB_DIM):
        """
        input_dim     - number of dense user-feature columns (sequence columns excluded)
        n_items       - item count; the table gets ``n_items + 1`` rows, with
                        row ``PAD_ID`` acting as the padding entry
        embedding_dim - width of the item embeddings and of the output vector
        """
        super().__init__()

        self.item_emb = nn.Embedding(n_items + 1, embedding_dim, padding_idx=PAD_ID)

        # Maps each (rating, timestamp) pair to one scalar pooling logit.
        self.rating_proj = nn.Linear(2, 1)

        # Three hidden stages (Linear -> ReLU -> Dropout) and a final linear
        # projection back down to ``embedding_dim``.
        stages = []
        width_in = input_dim + embedding_dim
        for width_out in (512, 384, 256):
            stages.extend([nn.Linear(width_in, width_out), nn.ReLU(), nn.Dropout(0.2)])
            width_in = width_out
        stages.append(nn.Linear(width_in, embedding_dim))
        self.mlp = nn.Sequential(*stages)

    def forward(self, batch):
        """Return one L2-normalised vector per user in ``batch``.

        ``batch`` is a mapping with the keys ``'movies'``, ``'ratings'``,
        ``'timestamps'`` and ``'user_features'``. The three sequence tensors
        are assumed to share a (batch, seq_len) shape - TODO confirm against
        the collate function.
        """
        item_ids = batch['movies']
        item_vectors = self.item_emb(item_ids)

        # One weight in (0, 1) per history position, from rating and timestamp.
        pair_features = torch.stack([batch['ratings'], batch['timestamps']], dim=-1)
        weights = torch.sigmoid(self.rating_proj(pair_features))

        # Zero the weight of padded positions so they drop out of the average;
        # the trailing axis lets the mask broadcast against ``weights``.
        is_real_item = (item_ids != PAD_ID).unsqueeze(-1)
        weights = weights * is_real_item

        # Weighted mean over the sequence axis; the clamp avoids a division by
        # zero when every position of a row is padding.
        weighted_sum = (item_vectors * weights).sum(1)
        weight_total = weights.sum(1).clamp_min(1e-6)
        pooled = weighted_sum / weight_total

        mlp_in = torch.cat([batch['user_features'], pooled], dim=-1)
        mlp_out = self.mlp(mlp_in)
        return F.normalize(mlp_out, dim=1)


In [372]:
import random

# Dense per-user columns fed to the model, in the order the model receives them.
USER_FEATURE_COLUMNS = [
    'num_rating', 'avg_rating', 'weekend_watcher',
    'genre_Action', 'genre_Adventure', 'genre_Animation', 'genre_Comedy',
    'genre_Crime', 'genre_Documentary', 'genre_Drama', 'genre_Family',
    'genre_Fantasy', 'genre_History', 'genre_Horror', 'genre_Music',
    'genre_Mystery', 'genre_Romance', 'genre_Science Fiction', 'genre_TV Movie',
    'genre_Thriller', 'genre_War', 'genre_Western',
    'type_of_viewer_negative', 'type_of_viewer_neutral', 'type_of_viewer_positive',
]


def collate(batch, ratings=None):
    """Turn a list of user rows into model inputs plus one pos/neg item per user.

    Parameters
    ----------
    batch : list
        Rows (as returned by ``UserDataset.__getitem__``) providing ``userId``,
        ``movies_seq``, ``ratings_seq``, ``ts_seq`` and every column listed in
        ``USER_FEATURE_COLUMNS``. The sequences must share one length because
        they are combined with ``torch.stack`` (presumably padded upstream).
    ratings : pandas.DataFrame, optional
        Table with ``userId``, ``pos`` and ``neg`` columns. Defaults to the
        notebook-level ``ratings_groupped_ids`` frame.

    Returns
    -------
    dict
        ``{"input": {"user_features", "movies", "ratings", "timestamps"},
        "pos": LongTensor, "neg": LongTensor}``.

    Raises
    ------
    KeyError
        If a user in ``batch`` has no row in ``ratings``.
    """
    if ratings is None:
        ratings = ratings_groupped_ids

    user_ids = [row['userId'] for row in batch]

    # Scan the ratings table once per batch rather than once per row. When a
    # user appears several times, the first row wins (same as ``.iloc[0]``).
    selected = ratings.loc[ratings['userId'].isin(user_ids)].drop_duplicates(
        subset='userId', keep='first'
    )
    choices_by_user = {
        uid: (pos_ids, neg_ids)
        for uid, pos_ids, neg_ids in zip(selected['userId'], selected['pos'], selected['neg'])
    }

    missing = [uid for uid in user_ids if uid not in choices_by_user]
    if missing:
        raise KeyError(f"No pos/neg ratings found for userId(s): {missing}")

    user_features, movies, rating_seqs, timestamps, pos, neg = [], [], [], [], [], []

    for row, user_id in zip(batch, user_ids):
        movies.append(torch.tensor(row['movies_seq'], dtype=torch.long))
        rating_seqs.append(torch.tensor(row['ratings_seq'], dtype=torch.float32))
        timestamps.append(torch.tensor(row['ts_seq'], dtype=torch.float32))

        dense = row[USER_FEATURE_COLUMNS].astype('float32').to_numpy()
        user_features.append(torch.tensor(dense, dtype=torch.float32))

        # One random positively rated and one random negatively rated item id,
        # used during training to calculate the BPR loss.
        pos_ids, neg_ids = choices_by_user[user_id]
        pos.append(int(random.choice(pos_ids)))
        neg.append(int(random.choice(neg_ids)))

    return {
        "input": {
            "user_features": torch.stack(user_features),
            "movies": torch.stack(movies),
            "ratings": torch.stack(rating_seqs),
            "timestamps": torch.stack(timestamps),
        },
        "pos": torch.tensor(pos, dtype=torch.long),
        "neg": torch.tensor(neg, dtype=torch.long),
    }

In [373]:
BATCH_SIZE = 256

dataset = UserDataset(user_features)
# Use the constant so the loader cannot drift away from BATCH_SIZE.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print('Device:', device)

Device: mps


In [374]:
import torch.optim as optim

# Number of dense user-feature columns; collate() selects 25 of them.
N_USER_FEATURES = 25
# Size passed as n_items to the item embedding table.
# NOTE(review): presumably the largest movie id in the data - confirm.
N_ITEMS = 86477

model = UserTower(input_dim=N_USER_FEATURES, n_items=N_ITEMS)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# NOTE(review): the training loop in this notebook computes a BPR loss inline
# and does not call loss_fn; kept in case other cells rely on it.
loss_fn = F.mse_loss

In [375]:
def to_device(data, device):
    """Move every tensor inside a (possibly nested) dict onto ``device``.

    Tensors are moved with ``.to(device)``, dicts are rebuilt with their values
    processed recursively, and any other object is returned unchanged.
    """
    if torch.is_tensor(data):
        return data.to(device)
    if not isinstance(data, dict):
        return data

    moved = {}
    for key, value in data.items():
        moved[key] = to_device(value, device)
    return moved

def train_one_epoch(log_every=1):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Uses the notebook-level ``model``, ``optimizer``, ``dataloader`` and
    ``device``. For every batch it scores one positive and one negative item
    against the user vector and minimises the BPR loss
    ``-log(sigmoid(pos_score - neg_score))``.

    Parameters
    ----------
    log_every : int, optional
        Print the batch loss every ``log_every`` batches. The default of 1
        prints every batch; pass 0 or None to disable printing.

    Returns
    -------
    float
        Mean of the per-batch losses over the whole epoch, or 0.0 when the
        loader yields no batches. (Reporting is no longer tied to BATCH_SIZE,
        so short epochs and trailing batches are included.)
    """
    total_loss = 0.0
    n_batches = 0

    for i, data in enumerate(dataloader):
        data = to_device(data, device)
        optimizer.zero_grad()

        u = model(data['input'])
        pos_vec = model.item_emb(data['pos'])
        neg_vec = model.item_emb(data['neg'])

        # Dot-product score of the user vector against each sampled item.
        pos_score = (u * pos_vec).sum(dim=-1)
        neg_score = (u * neg_vec).sum(dim=-1)
        # BPR Loss
        loss = -F.logsigmoid(pos_score - neg_score).mean()

        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call copies from the device.
        batch_loss = loss.item()
        total_loss += batch_loss
        n_batches += 1

        if log_every and (i + 1) % log_every == 0:
            print(f"Loss {batch_loss:.4f}")

    if n_batches == 0:
        return 0.0
    return total_loss / n_batches


In [376]:
from datetime import datetime
from tqdm import tqdm

# Start-time tag for this run; not referenced again in the cells shown here.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
EPOCHS = 5

# Large starting value for a best validation loss; the cells shown here never
# update it.
best_vloss = 1_000_000

for epoch in tqdm(range(EPOCHS)):
    # Training mode, so the Dropout layers of the model are active.
    model.train()
    avg_loss = train_one_epoch()
    print(f"Epoch {epoch} loss: {avg_loss}")


  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[[ 0.2018,  0.0913,  1.0534,  ..., -0.5833,  0.3585,  0.1800],
         [-1.7587,  0.0679, -0.5471,  ...,  0.7515, -0.8811,  0.4126],
         [-0.5608, -0.0486,  2.5319,  ...,  1.1147, -0.4309, -1.1771],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0036,  1.1810, -0.0382,  ..., -1.9933,  0.8092, -0.2423],
         [-0.5264, -1.3265,  0.5017,  ...,  0.4687, -1.2012, -1.1992]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.0544,  1.8664,  0.6685,  ...,  0.2453, -1.7929,  0.2504],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0175, -1.4615, -0.8880,  ..., -1.5326, -1.2142, -0.8367],
         [ 0.3423,  0.8463,  1.0878,  ..., -0

  0%|          | 0/5 [00:11<?, ?it/s]

tensor([[[-1.1608, -0.6259, -0.2744,  ..., -2.2534,  0.6946,  1.4416],
         [-2.5224, -1.2694,  0.2216,  ..., -0.3758, -0.4260,  0.5676],
         [ 0.3907, -0.2048,  0.8043,  ..., -0.9131, -0.8029,  1.5353],
         ...,
         [-0.1394,  0.6550,  0.4033,  ..., -0.4418,  2.1019,  0.1293],
         [-0.4139, -1.4262, -2.4583,  ..., -0.3608, -0.5740, -0.7085],
         [ 0.8552, -1.6553, -0.2790,  ..., -1.6830, -0.3566,  0.7237]],

        [[-2.2739,  0.5030,  0.6034,  ...,  0.9267, -0.3998,  0.1953],
         [-0.2742, -1.4443, -1.1132,  ..., -0.9911, -0.0740, -1.7443],
         [-1.2019,  1.4730, -0.5805,  ...,  0.6086,  0.4204, -1.8490],
         ...,
         [ 0.4977,  1.0047,  0.4292,  ...,  0.6661, -1.2849, -2.0172],
         [-1.2186, -1.5485, -0.6473,  ...,  1.7477, -0.3332, -1.6221],
         [ 0.4249, -0.0360,  0.8686,  ..., -0.5766, -1.6140, -0.0182]],

        [[ 1.3591,  0.1374,  0.9702,  ...,  1.4643,  1.7539,  1.8391],
         [ 0.0000,  0.0000,  0.0000,  ...,  0




KeyboardInterrupt: 