In [21]:
import torch.nn as nn
import torch
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

import pandas as pd
import numpy as np

In [22]:
user_features = pd.read_parquet('../datasets/user_features_clean.parquet')
ratings_groupped_ids = pd.read_parquet('../datasets/ratings_groupped_ids.parquet')

In [23]:
print(user_features.info())
print(ratings_groupped_ids.info())

empty_pos_ratings = ratings_groupped_ids['pos'].apply(lambda x: len(x) == 0).sum()
empty_neg_ratings = ratings_groupped_ids['neg'].apply(lambda x: len(x) == 0).sum()

if empty_pos_ratings != 0 or empty_neg_ratings != 0:
    print(f'Empty ratings: pos: {empty_pos_ratings}, neg: {empty_neg_ratings}')
    raise Exception("Users without a single pos/neg rating exist in the ratings_groupped_ids dataset")

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 198832 entries, 0 to 198831
Data columns (total 29 columns):
 #   Column                   Non-Null Count   Dtype  
---  ------                   --------------   -----  
 0   userId                   198832 non-null  int64  
 1   num_rating               198832 non-null  float64
 2   avg_rating               198832 non-null  float64
 3   weekend_watcher          198832 non-null  float64
 4   genre_Action             198832 non-null  float64
 5   genre_Adventure          198832 non-null  float64
 6   genre_Animation          198832 non-null  float64
 7   genre_Comedy             198832 non-null  float64
 8   genre_Crime              198832 non-null  float64
 9   genre_Documentary        198832 non-null  float64
 10  genre_Drama              198832 non-null  float64
 11  genre_Family             198832 non-null  float64
 12  genre_Fantasy            198832 non-null  float64
 13  genre_History            198832 non-null  float64
 14  genr

In [24]:
class UserDataset(Dataset):
    def __init__(self, user_features):
        self.data = user_features
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data.iloc[idx]



In [25]:
EMB_DIM = 64
PAD_ID = 0

class UserTower(nn.Module):
    def __init__(self, input_dim, n_items, embedding_dim=EMB_DIM):
        '''
        input_dim - the number of columns in user features, without sequence columns
        '''
        super().__init__()

        self.item_emb = nn.Embedding(n_items + 1, embedding_dim, padding_idx=PAD_ID)

        # A layer to project rating and timestamp into a scalar weight
        self.rating_proj = nn.Linear(2, 1)

        self.mlp = nn.Sequential(
            nn.Linear(input_dim + embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 384),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(384, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, embedding_dim)
        )
    
    def forward(self, batch):
        # Embed movieIds liked by user
        m = self.item_emb(batch['movies'])

        # Get weights 
        x = torch.stack([batch['ratings'], batch['timestamps']], dim=-1)
        w = torch.sigmoid(self.rating_proj(x))

        mask = (batch['movies'] != PAD_ID).unsqueeze(-1) # ????????????
        w = w * mask

        # weighted mean-pool
        pooled = (m * w).sum(1) / (w.sum(1).clamp_min(1e-6))

        input = torch.cat([batch['user_features'], pooled], dim=-1)
        output = self.mlp(input)
        u = F.normalize(output, dim = 1)
        return u


In [26]:
import random

def collate(batch):
    user_features, movies, ratings, timestamps, pos, neg = [], [], [], [], [], []

    for row in batch:
        movies.append(torch.tensor(row['movies_seq'], dtype=torch.long))
        ratings.append(torch.tensor(row['ratings_seq'], dtype=torch.float32))
        timestamps.append(torch.tensor(row['ts_seq'], dtype=torch.float32))

        userId = row['userId']

        r = row[['num_rating', 'avg_rating', 'weekend_watcher', 'genre_Action', 'genre_Adventure', 'genre_Animation', 'genre_Comedy', 'genre_Crime', 'genre_Documentary', 'genre_Drama', 'genre_Family', 'genre_Fantasy', 'genre_History', 'genre_Horror', 'genre_Music', 'genre_Mystery', 'genre_Romance', 'genre_Science Fiction', 'genre_TV Movie', 'genre_Thriller', 'genre_War', 'genre_Western', 'type_of_viewer_negative', 'type_of_viewer_neutral', 'type_of_viewer_positive']]
        r = r.astype('float32').values
        user_features.append(torch.tensor(r, dtype=torch.float32))
        
        # Get a random movieId that was rated positively and one that was rated negatively. 
        # Used during training to calculate BPR loss. 
        posAndNegRow = ratings_groupped_ids[ratings_groupped_ids['userId'] == userId].iloc[0]
        pos.append(torch.tensor(random.choice(posAndNegRow['pos']), dtype=torch.long))
        neg.append(torch.tensor(random.choice(posAndNegRow['neg']), dtype=torch.long))

    return {
        "input": {
            "user_features": torch.stack(user_features),
            "movies": torch.stack(movies),
            "ratings": torch.stack(ratings),
            "timestamps": torch.stack(timestamps),
        },
        "pos": torch.as_tensor(pos, dtype=torch.long),
        "neg": torch.as_tensor(neg, dtype=torch.long)
    }

In [27]:
BATCH_SIZE = 256

dataset = UserDataset(user_features)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True, collate_fn=collate)

device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.mps.is_available():
    device = torch.device('mps')
print('Device:', device)

Device: mps


In [28]:
import torch.optim as optim

model = UserTower(input_dim=25, n_items=86477).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = F.mse_loss

In [29]:
def to_device(data, device):
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    elif torch.is_tensor(data):
        return data.to(device)
    else:
        return data

def train_one_epoch():
    running_loss = 0.0
    last_loss = 0.0

    for i, data in enumerate(dataloader):
        data = to_device(data, device)
        optimizer.zero_grad()

        u = model(data['input'])
        pos_vec = model.item_emb(data['pos'])
        neg_vec = model.item_emb(data['neg'])

        pos_score = (u * pos_vec).sum(dim=-1)
        neg_score = (u * neg_vec).sum(dim=-1)
        # BPR Loss
        loss = -F.logsigmoid(pos_score - neg_score).mean()

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % BATCH_SIZE == BATCH_SIZE - 1:
            last_loss = running_loss / BATCH_SIZE
            running_loss = 0.0
    
    return last_loss


In [None]:
from datetime import datetime
from tqdm import tqdm

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
EPOCHS = 50

best_vloss = 1_000_000

for epoch in tqdm(range(EPOCHS)):
    model.train(True)
    avg_loss = train_one_epoch()

    print(f"Epoch {epoch + 1} loss: {avg_loss}")


  2%|▏         | 1/50 [01:17<1:03:35, 77.87s/it]

Epoch 1 loss: 0.6563315810635686


  4%|▍         | 2/50 [02:35<1:01:58, 77.47s/it]

Epoch 2 loss: 0.6017008032649755


  6%|▌         | 3/50 [03:52<1:00:33, 77.31s/it]

Epoch 3 loss: 0.577974382089451


  8%|▊         | 4/50 [05:10<59:32, 77.67s/it]  

Epoch 4 loss: 0.5650902739726007


 10%|█         | 5/50 [06:27<58:08, 77.52s/it]

Epoch 5 loss: 0.55233197088819


 12%|█▏        | 6/50 [07:45<56:49, 77.49s/it]

Epoch 6 loss: 0.5447721492964774


 14%|█▍        | 7/50 [09:02<55:28, 77.41s/it]

Epoch 7 loss: 0.5334153134608641


 16%|█▌        | 8/50 [10:19<54:10, 77.39s/it]

Epoch 8 loss: 0.5308127694297582


 18%|█▊        | 9/50 [11:36<52:48, 77.29s/it]

Epoch 9 loss: 0.5235685247462243


 20%|██        | 10/50 [12:53<51:30, 77.26s/it]

Epoch 10 loss: 0.518739445367828


 22%|██▏       | 11/50 [14:11<50:15, 77.32s/it]

Epoch 11 loss: 0.5141276411013678


 24%|██▍       | 12/50 [15:29<49:12, 77.69s/it]

Epoch 12 loss: 0.5135528722312301


 26%|██▌       | 13/50 [16:46<47:47, 77.49s/it]

Epoch 13 loss: 0.5064283803803846


 28%|██▊       | 14/50 [18:03<46:23, 77.32s/it]

Epoch 14 loss: 0.5042746822582558


 30%|███       | 15/50 [19:20<45:02, 77.20s/it]

Epoch 15 loss: 0.4995563494740054


 32%|███▏      | 16/50 [20:37<43:41, 77.11s/it]

Epoch 16 loss: 0.5009204463567585


 34%|███▍      | 17/50 [21:54<42:22, 77.03s/it]

Epoch 17 loss: 0.49759945820551366


 36%|███▌      | 18/50 [23:11<41:03, 76.98s/it]

Epoch 18 loss: 0.4917936873389408


 38%|███▊      | 19/50 [24:28<39:44, 76.93s/it]

Epoch 19 loss: 0.48732779605779797


 40%|████      | 20/50 [25:45<38:29, 76.98s/it]

Epoch 20 loss: 0.4881518838228658


 42%|████▏     | 21/50 [27:02<37:12, 76.98s/it]

Epoch 21 loss: 0.4876002741511911


 44%|████▍     | 22/50 [28:19<35:54, 76.96s/it]

Epoch 22 loss: 0.48383771430235356


 46%|████▌     | 23/50 [29:36<34:37, 76.93s/it]

Epoch 23 loss: 0.48519956215750426


 48%|████▊     | 24/50 [30:53<33:21, 76.99s/it]

Epoch 24 loss: 0.483124005026184


 50%|█████     | 25/50 [32:10<32:03, 76.95s/it]

Epoch 25 loss: 0.4782412914792076


 52%|█████▏    | 26/50 [33:27<30:47, 76.96s/it]

Epoch 26 loss: 0.47948009171523154


 54%|█████▍    | 27/50 [34:44<29:30, 76.98s/it]

Epoch 27 loss: 0.475331885041669


 56%|█████▌    | 28/50 [36:00<28:12, 76.93s/it]

Epoch 28 loss: 0.4733575980644673


 58%|█████▊    | 29/50 [37:17<26:55, 76.92s/it]

Epoch 29 loss: 0.4740801624720916


 60%|██████    | 30/50 [38:34<25:38, 76.93s/it]

Epoch 30 loss: 0.47044038644526154


 62%|██████▏   | 31/50 [39:51<24:21, 76.92s/it]

Epoch 31 loss: 0.47094915923662484


 64%|██████▍   | 32/50 [41:08<23:04, 76.91s/it]

Epoch 32 loss: 0.4668874393682927


 66%|██████▌   | 33/50 [42:25<21:47, 76.93s/it]

Epoch 33 loss: 0.4669450728688389


 68%|██████▊   | 34/50 [43:42<20:32, 77.03s/it]

Epoch 34 loss: 0.4652400676859543


 70%|███████   | 35/50 [44:59<19:15, 77.05s/it]

Epoch 35 loss: 0.46484355605207384


 72%|███████▏  | 36/50 [46:16<17:58, 77.06s/it]

Epoch 36 loss: 0.462912758695893


 74%|███████▍  | 37/50 [47:33<16:40, 76.98s/it]

Epoch 37 loss: 0.46123744710348547


 76%|███████▌  | 38/50 [48:50<15:23, 76.95s/it]

Epoch 38 loss: 0.4604658446041867


 78%|███████▊  | 39/50 [50:07<14:06, 76.98s/it]

Epoch 39 loss: 0.45747743477113545


 80%|████████  | 40/50 [51:24<12:49, 76.95s/it]

Epoch 40 loss: 0.45880209386814386


 82%|████████▏ | 41/50 [52:42<11:35, 77.33s/it]

Epoch 41 loss: 0.4534814398502931


 84%|████████▍ | 42/50 [54:00<10:19, 77.50s/it]

Epoch 42 loss: 0.45310874201823026


In [None]:
torch.save(model.state_dict(), 'user_tower.model')