In [1]:
import pandas as pd
from datasets import load_dataset

# Load the Books reviews dataset (Amazon Reviews 2023, Books category).
# NOTE(review): trust_remote_code=True lets `load_dataset` execute any loading script
# shipped in the Hub repo. If this repo only contains data files, the flag is
# unnecessary and should be dropped.
dataset = load_dataset("cogsci13/Amazon-Reviews-2023-Books-Review", "default", trust_remote_code=True)

# Extract relevant fields. Selecting the four columns and converting them in a single
# `to_pandas()` call keeps numeric data in Arrow/NumPy form instead of building an
# intermediate Python sequence per column, which matters for a split this large.
df = (
    dataset['full']
    .select_columns(['user_id', 'parent_asin', 'rating', 'timestamp'])
    .to_pandas()
    .rename(columns={'parent_asin': 'item_id'})
    [['user_id', 'item_id', 'rating', 'timestamp']]
)

# Leave-last-N-out split: each user's N most recent interactions form the test set.
# Users with N or fewer interactions stay entirely in train. Holding out their whole
# history would leave them with no training signal, so their test predictions would
# come from untrained embeddings.
N_TEST_PER_USER = 2

df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
df_sorted = df.sort_values(by=['user_id', 'timestamp'])

# Number of interactions of the row's user, aligned to df_sorted's index.
user_interaction_counts = df_sorted['user_id'].map(df_sorted['user_id'].value_counts())
eligible = df_sorted[user_interaction_counts > N_TEST_PER_USER]
test_df = eligible.groupby('user_id').tail(N_TEST_PER_USER)
# test_df keeps df's original index labels, so dropping them leaves exactly the train rows.
train_df = df_sorted.drop(test_df.index)

assert len(train_df) + len(test_df) == len(df)
# Every user keeps at least one interaction for training.
assert train_df['user_id'].nunique() == df['user_id'].nunique()

Resolving data files:   0%|          | 0/33 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MatrixFactorization(nn.Module):
    """Dot-product matrix factorisation with learnable user and item embeddings.

    Args:
        num_users: Number of rows in the user embedding table.
        num_items: Number of rows in the item embedding table.
        embedding_dim: Size of each latent vector.
        init_std: If given, both tables are re-initialised from N(0, init_std**2).
            If None (default), PyTorch's standard ``nn.Embedding`` initialisation
            is kept.
    """

    def __init__(self, num_users, num_items, embedding_dim, init_std=None):
        super().__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        if init_std is not None:
            self._init_weights(init_std)

    def _init_weights(self, std):
        """Re-initialise both embedding tables from a zero-mean normal with the given std."""
        nn.init.normal_(self.user_embedding.weight, std=std)
        nn.init.normal_(self.item_embedding.weight, std=std)

    def forward(self, user_indices, item_indices):
        """Return sigmoid(p_u . q_i) for each (user, item) index pair.

        ``user_indices`` and ``item_indices`` are integer index tensors of the same
        (or broadcastable) shape, and the result has that shape. Reducing over the
        last (embedding) axis lets callers score, e.g., a (batch, num_candidates)
        grid as well as a flat batch of pairs.
        """
        user_vecs = self.user_embedding(user_indices)
        item_vecs = self.item_embedding(item_indices)
        scores = (user_vecs * item_vecs).sum(dim=-1)
        return torch.sigmoid(scores)

## Loss Functions

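$$
\mathcal{L}_{\text{MSE}} = \sum_{(u,i)} \mathbf{1}_{ui} \left( \frac{r_{ui} - 1}{R_{\max} - 1} - \sigma(\mathbf{p}_u \cdot \mathbf{q}_i) \right)^2 + \lambda_U \sum_u \|\mathbf{p}_u\|^2 + \lambda_I \sum_i \|\mathbf{q}_i\|^2
$$

Here $\mathbf{1}_{ui} = 1$ if user $u$ rated item $i$ (and $0$ otherwise). Ratings on the scale $[1, R_{\max}]$ are rescaled to $[0, 1]$ so they are comparable with the sigmoid output. In `mse_loss` the squared-error term is averaged over a mini-batch of observed pairs rather than summed, and a single `lambda_reg` plays the role of both $\lambda_U$ and $\lambda_I$.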

In [None]:
def mse_loss(model, user_indices, item_indices, ratings, lambda_reg=1e-4, r_max=5, r_min=1):
    """Regularised MSE between the model's sigmoid scores and ratings rescaled to [0, 1].

    Args:
        model: Module with ``user_embedding``/``item_embedding`` tables whose forward
            returns one score in [0, 1] per (user, item) pair.
        user_indices, item_indices: Index tensors of the rated pairs in the batch.
        ratings: Observed ratings for those pairs, on the scale [r_min, r_max].
            Any tensor/array-like is accepted and cast to the predictions' dtype.
        lambda_reg: L2 weight applied to the *entire* user and item tables.
        r_max, r_min: Bounds of the rating scale used for min-max rescaling.

    Returns:
        Scalar tensor: batch-mean squared error plus the L2 penalty.
    """
    preds = model(user_indices, item_indices)
    # Match the predictions' dtype/device: e.g. float64 targets built from a pandas
    # column against float32 predictions can make mse_loss's backward fail.
    ratings = torch.as_tensor(ratings, dtype=preds.dtype, device=preds.device)
    ratings_norm = (ratings - r_min) / (r_max - r_min)
    mse = F.mse_loss(preds, ratings_norm)
    # Squared Frobenius norm of each full table (every row is penalised every step).
    reg_term = lambda_reg * (
        model.user_embedding.weight.pow(2).sum() + model.item_embedding.weight.pow(2).sum()
    )
    return mse + reg_term

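$$
\mathcal{L}_{\text{BPR}} = -\sum_{(u,i,j)} \ln \sigma(\mathbf{p}_u \cdot \mathbf{q}_i - \mathbf{p}_u \cdot \mathbf{q}_j) + \lambda \|\Theta\|^2
$$

Here $i$ is an item user $u$ interacted with and $j$ is a negative (non-interacted) item. This is the negated BPR-Opt criterion, so it is *minimised*. `bpr_loss` averages the ranking term over the batch and applies the $\lambda$ penalty to the user and item embeddings that appear in that batch.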

In [3]:
def bpr_loss(model, user_indices, pos_item_indices, neg_item_indices, lambda_reg=1e-4):
    """Bayesian Personalised Ranking loss on raw (pre-sigmoid) dot-product scores.

    Args:
        model: Module with ``user_embedding`` and ``item_embedding`` tables.
        user_indices: Index tensor of users in the batch.
        pos_item_indices: Items each user interacted with (same shape as users).
        neg_item_indices: Negative items for the same users (same shape as users).
        lambda_reg: L2 weight on the embeddings looked up for this batch.

    Returns:
        Scalar tensor: -mean(log sigmoid(pos - neg)) plus the L2 penalty.
    """
    user_vecs = model.user_embedding(user_indices)
    pos_item_vecs = model.item_embedding(pos_item_indices)
    neg_item_vecs = model.item_embedding(neg_item_indices)
    pos_scores = (user_vecs * pos_item_vecs).sum(dim=-1)
    neg_scores = (user_vecs * neg_item_vecs).sum(dim=-1)
    # logsigmoid is the numerically stable form of log(sigmoid(x)); the naive form
    # underflows to log(0) = -inf (and NaN gradients) for strongly negative margins.
    loss = -F.logsigmoid(pos_scores - neg_scores).mean()
    reg_term = lambda_reg * (
        user_vecs.pow(2).sum() + pos_item_vecs.pow(2).sum() + neg_item_vecs.pow(2).sum()
    )
    return loss + reg_term

$$
\mathcal{L}_{\text{AU}} = \mathbb{E}_{(u,i)} \|\mathbf{p}_u - \mathbf{q}_i\|^2 + \lambda \left( 
\log \mathbb{E}_{(u,v)} e^{-\frac{1}{2}\|\mathbf{p}_u - \mathbf{p}_v\|^2} + 
\log \mathbb{E}_{(i,j)} e^{-\frac{1}{2}\|\mathbf{q}_i - \mathbf{q}_j\|^2} \right)
$$

In [None]:
def _log_mean_gaussian_potential(embeddings):
    """Return log E_{a,b} exp(-||e_a - e_b||^2 / 2) over all ordered pairs of rows.

    Self-pairs (a == b) are included, so each contributes exp(0) = 1 and the mean
    stays bounded away from zero. Memory is O(len(embeddings)^2).
    """
    sq_dists = torch.cdist(embeddings, embeddings, p=2).pow(2)
    return torch.log(torch.exp(-0.5 * sq_dists).mean())


def alignment_uniformity_loss(model, user_item_pairs, lambda_reg=1e-2, batch_uniformity=False):
    """Alignment + uniformity loss on user/item embeddings.

    Args:
        model: Module with ``user_embedding`` and ``item_embedding`` tables.
        user_item_pairs: Integer tensor whose column 0 holds user indices and
            column 1 the matching item indices.
        lambda_reg: Weight of the uniformity terms.
        batch_uniformity: If False (default), uniformity is computed over every row
            of both embedding tables. That needs O(num_users^2 + num_items^2) memory
            and is infeasible for large catalogues. If True, it is computed over
            the distinct users/items present in ``user_item_pairs``.

    Returns:
        Scalar tensor: mean squared user-item distance over the pairs plus
        ``lambda_reg`` times the user and item uniformity terms.
    """
    user_idx = user_item_pairs[:, 0]
    item_idx = user_item_pairs[:, 1]
    user_vecs = model.user_embedding(user_idx)
    item_vecs = model.item_embedding(item_idx)

    alignment = (user_vecs - item_vecs).pow(2).sum(dim=1).mean()

    if batch_uniformity:
        # Deduplicate so each distinct entity in the batch is counted once.
        user_pool = model.user_embedding(torch.unique(user_idx))
        item_pool = model.item_embedding(torch.unique(item_idx))
    else:
        user_pool = model.user_embedding.weight
        item_pool = model.item_embedding.weight

    uniformity_user = _log_mean_gaussian_potential(user_pool)
    uniformity_item = _log_mean_gaussian_potential(item_pool)

    return alignment + lambda_reg * (uniformity_user + uniformity_item)