In [1]:
import time
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from tqdm import tqdm as tq
import gzip, pickle
import jax
from sklearn.model_selection import train_test_split

In [2]:
df = pd.read_csv("100k_csv/ratings.csv").sort_values(by='userId')

In [3]:
# Split the data into train and test sets
train_df, test_df = train_test_split(df, test_size=0.01, random_state=42)

# Get unique users and movies in train and test sets
users_in_train = train_df['userId'].unique()
users_in_test = test_df['userId'].unique()
movies_in_train = train_df['movieId'].unique()
movies_in_test = test_df['movieId'].unique()

# Convert to sets for efficient comparison
train_users_set = set(users_in_train)
test_users_set = set(users_in_test)
train_movies_set = set(movies_in_train)
test_movies_set = set(movies_in_test)

# Find users and movies to move from test to train
users_to_move = test_users_set - train_users_set  # Users only in the test set
movies_to_move = test_movies_set - train_movies_set  # Movies only in the test set

# Move users from test to train
if users_to_move:
    user_data_to_move = test_df[test_df['userId'].isin(users_to_move)]
    train_df = pd.concat([train_df, user_data_to_move])  # Move these users to train
    test_df = test_df[~test_df['userId'].isin(users_to_move)]  # Remove from test

# Move movies from test to train
if movies_to_move:
    movie_data_to_move = test_df[test_df['movieId'].isin(movies_to_move)]
    train_df = pd.concat([train_df, movie_data_to_move])  # Move these movies to train
    test_df = test_df[~test_df['movieId'].isin(movies_to_move)]  # Remove from test

# Get unique users and movies in train and test sets
users_in_train = train_df['userId'].unique()
users_in_test = test_df['userId'].unique()
movies_in_train = train_df['movieId'].unique()
movies_in_test = test_df['movieId'].unique()

# Convert to sets for efficient comparison
train_users_set = set(users_in_train)
test_users_set = set(users_in_test)
train_movies_set = set(movies_in_train)
test_movies_set = set(movies_in_test)

# Verify that all test users and movies are now in train using set intersection
all_test_users_in_train = test_users_set.issubset(train_users_set)
all_test_movies_in_train = test_movies_set.issubset(train_movies_set)

# Output the result
print(f"All test users in train: {all_test_users_in_train}")
print(f"All test movies in train: {all_test_movies_in_train}")

# Check the result for the number of unique users and movies
print(f"Number of users in Train: {train_df['userId'].nunique()}")
print(f"Number of users in Test: {test_df['userId'].nunique()}")
print(f"Number of movies in Train: {train_df['movieId'].nunique()}")
print(f"Number of movies in Test: {test_df['movieId'].nunique()}")

All test users in train: True
All test movies in train: True
Number of users in Train: 610
Number of users in Test: 345
Number of movies in Train: 9724
Number of movies in Test: 755


In [4]:
len(train_df), len(test_df), len(train_df) + len(test_df) == len(df)

(99866, 970, True)

In [5]:
def split_users_into_chunks(df, num_chunks):
    """
    Split a DataFrame into chunks based on unique users, 
    with users shuffled in each call.

    Parameters:
    df (pd.DataFrame): The input DataFrame containing user ratings.
    num_chunks (int): The number of desired DataFrames (chunks).

    Returns:
    list: A list of lists, each containing indices of rows for a unique group of users.
    """
    # Step 1: Get unique users and shuffle them
    unique_users = df['userId'].unique()
    np.random.shuffle(unique_users)

    # Step 2: Split users into groups
    user_groups = np.array_split(unique_users, num_chunks)

    # Step 3: Create a list of indices based on user groups
    index_groups = []
    for group in user_groups:
        # Get the indices of the rows corresponding to the current group of users
        indices = df[df['userId'].isin(group)].index.tolist()
        
        index_groups.append(indices)

    return index_groups, user_groups

In [6]:
import jax
import jax.numpy as jnp
from jax import grad

# Objective Function: J(U, V)
@jax.jit
def loss(U, V, b_u, b_i, mu, R, lam):
    # Compute the predicted rating matrix including biases
    R_hat = mu + b_u[:, None] + b_i[None, :] + jnp.dot(U, V.T)
    
    # Compute the error matrix
    E = R - R_hat
    E = jnp.where(R == 0, 0, E)
    
    # Calculate the loss (mean squared error + regularization)
    squared_error = jnp.sum(E**2)/jnp.sum(R > 0)
    regularization = lam * (jnp.sum(U**2) + jnp.sum(V**2) + jnp.sum(b_u**2) + jnp.sum(b_i**2))
    
    # Return the total objective value
    return squared_error + regularization

# Gradients with respect to U, V, b_u, and b_i using JAX
grad_U = jax.jit(grad(loss, argnums=0))  # Gradient with respect to U
grad_V = jax.jit(grad(loss, argnums=1))  # Gradient with respect to V
grad_b_u = jax.jit(grad(loss, argnums=2))  # Gradient with respect to user biases
grad_b_i = jax.jit(grad(loss, argnums=3))  # Gradient with respect to item biases

# SGD update step
@jax.jit
def sgd_step(U, V, b_u, b_i, mu, R, lam, alpha):
    dU = grad_U(U, V, b_u, b_i, mu, R, lam)  # Gradient wrt U
    dV = grad_V(U, V, b_u, b_i, mu, R, lam)  # Gradient wrt V
    db_u = grad_b_u(U, V, b_u, b_i, mu, R, lam)  # Gradient wrt user biases
    db_i = grad_b_i(U, V, b_u, b_i, mu, R, lam)  # Gradient wrt item biases
    
    # Update the parameters U, V, b_u, and b_i
    U_new = U - alpha * dU
    V_new = V - alpha * dV
    b_u_new = b_u - alpha * db_u
    b_i_new = b_i - alpha * db_i
    
    return U_new, V_new, b_u_new, b_i_new

def matrix_factorization_sgd(weights, R, mu, lam, alpha, iterations):
    U, V, b_u, b_i = weights
    
    def body(carry, _):
        U, V, b_u, b_i = carry
        # Perform the SGD update
        U, V, b_u, b_i = sgd_step(U, V, b_u, b_i, mu, R, lam, alpha)
        # Return the updated values and the carry
        return (U, V, b_u, b_i), None
    
    carry = (U, V, b_u, b_i)
    
    # Use lax.scan to perform the updates
    final_carry, _ = jax.lax.scan(body, carry, jnp.arange(iterations))
    
    # Final carry will contain the updated U, V, b_u, b_i
    return final_carry

def train(df, num_chunks):
    # Parameters
    K = 10  # Number of latent factors
    lam = 0.1  # Regularization strength
    alpha = 0.1  # Learning rate
    iterations = 1000  # Number of iterations
    number_of_epochs = 5
    num_users, num_items = df.userId.nunique(), df.movieId.nunique()
    key = jax.random.PRNGKey(seed=42)
    u_key, v_key, b_u_key, b_i_key = jax.random.split(key, 4)
    
    jit_loss = jax.jit(loss)
    U = jax.random.normal(u_key, shape=(num_users, K)) * 0.01
    V = jax.random.normal(v_key, shape=(num_items, K)) * 0.01
    b_u = jnp.zeros(num_users)  # Initialize user biases
    b_i = jnp.zeros(num_items)  # Initialize item biases
    mu = df['rating'].mean()
    
    for epoch in range(number_of_epochs):
        total_loss = 0
        index_groups, user_groups = split_users_into_chunks(df, num_chunks=num_chunks)
        for i, (indices, user) in enumerate(zip(index_groups, user_groups)):
            # Pivot the DataFrame to create the user-movie matrix
            R = df.loc[indices].pivot(index='userId', columns='movieId', values='rating').fillna(0).to_numpy()
            movie = df.loc[indices].movieId.unique()
            weights = U[user], V[movie], b_u[user], b_i[movie]
            
            # Perform SGD step and update weights
            U_, V_, b_u_, b_i_ = matrix_factorization_sgd(weights, R, mu, lam, alpha, iterations)
            U = U.at[user].set(U_)  # Update U at the index 'user'
            b_u = b_u.at[user].set(b_u_)  # Update b_u at the index 'user'
            V = V.at[movie].set(V_)
            b_i = b_i.at[movie].set(b_i_)
            # Calculate the loss for this chunk and accumulate it
            current_loss = loss(U_, V_, b_u_, b_i_, mu, R, lam)
            
            if i % 10 == 0:
                print(f"Steps {(epoch + 1) * (i + 1) * iterations}, Loss: {current_loss}")
        
    
    return U, V, b_u, b_i

In [None]:
del df

U, V, b_u, b_i = train(train_df, 1

Steps 50, Loss: 1.3080358505249023
Steps 550, Loss: 0.6956747174263
Steps 1050, Loss: 1.148828148841858
Steps 1550, Loss: 0.7033240795135498
Steps 2050, Loss: 1.100050687789917
Steps 100, Loss: 1.031233549118042
Steps 1100, Loss: 0.8425652980804443
Steps 2100, Loss: 0.7444784641265869
Steps 3100, Loss: 1.196338176727295
Steps 4100, Loss: 1.2229278087615967
Steps 150, Loss: 1.2250889539718628
Steps 1650, Loss: 1.141504168510437
Steps 3150, Loss: 0.838412880897522
Steps 4650, Loss: 0.8361217975616455
Steps 6150, Loss: 1.1083306074142456
Steps 200, Loss: 1.1075458526611328
Steps 2200, Loss: 1.2313172817230225


In [8]:
def compute_rmse(U, V, b_u, b_i, mu, test_df):
    # Initialize variables to accumulate the total squared error and count of predictions
    total_squared_error = 0.0
    count = 0
    
    # Iterate through each row in the test DataFrame
    for _, row in tq(test_df.iterrows()):
        user_id = int(row['userId'])
        movie_id = int(row['movieId'])
        true_rating = row['rating']
        
        # Make prediction using the matrix factorization model
        predicted_rating = mu + b_u[user_id] + b_i[movie_id] + jnp.dot(U[user_id], V[movie_id])
        
        # Calculate the squared error for this prediction
        squared_error = (predicted_rating - true_rating) ** 2
        
        # Accumulate the squared error and count the predictions
        total_squared_error += squared_error
        count += 1
    
    # Calculate the Mean Squared Error (MSE)
    mse = total_squared_error / count
    
    # Return the Root Mean Square Error (RMSE)
    return jnp.sqrt(mse)

mu = train_df['rating'].mean()
rmse_value = compute_rmse(U, V, b_u, b_i, mu, test_df[:10000])
print(f"RMSE: {rmse_value}")


970it [00:01, 713.19it/s]

RMSE: 1.0451021194458008



