# Importing Packages and Mounting GDrive

In [None]:
import pandas as pd
import numpy as np
import jax.numpy as jnp
import jax

import matplotlib.pyplot as plt

from tqdm.auto import tqdm
import time

from sklearn.model_selection import train_test_split
from scipy import sparse

In [None]:
# Mount Google Drive at /content/drive (this cell only works inside Google Colab)
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
#Change directory here!
# The training CSV is read through a relative path further down, so the working
# directory must be the folder that contains it.
%cd /content/drive/MyDrive/Colab Notebooks


/content/drive/MyDrive/Colab Notebooks


# Reading Data

In [None]:
#Load Training Dataset
# The file is resolved relative to the current working directory.
# NOTE(review): later cells slice the converted array by position as
# (0 = user_id, 1 = anime_id, 2 = rating) — confirm the CSV has exactly these columns in this order.
train_all = pd.read_csv('assignment_2_ratings_train.csv')
train_all

Unnamed: 0,user_id,anime_id,rating
0,20170,10794,6
1,24592,21995,5
2,18358,7054,10
3,59267,488,7
4,69313,30544,4
...,...,...,...
4436063,32872,15061,7
4436064,66206,20507,8
4436065,46386,834,7
4436066,10497,9159,7


In [None]:
#Shuffle and Split
# A fixed seed keeps the shuffle and the 80/20 train/validation split identical
# across kernel restarts, so downstream numbers can be reproduced.
SPLIT_SEED = 0
train_all_shuffled = train_all.sample(frac = 1, random_state = SPLIT_SEED)
train, valid = train_test_split(train_all_shuffled, test_size=0.2, random_state = SPLIT_SEED)

In [None]:
#Convert to JAX numpy arrays
# Columns are selected by name so that the positional slicing used later
# (0 = user_id, 1 = anime_id, 2 = rating) does not depend on the CSV column
# order or on extra columns such as a saved index.
RATING_COLUMNS = ['user_id', 'anime_id', 'rating']
train_np = jnp.array(train[RATING_COLUMNS].to_numpy())
val_np = jnp.array(valid[RATING_COLUMNS].to_numpy())

# NMF Training

In [None]:
#Split the training array column-wise: user ids, anime ids and observed ratings
train_user_id = train_np[:, 0]
train_anime_id = train_np[:, 1]
train_actual_ratings = train_np[:, 2]

In [None]:
#Distinct user / anime IDs that occur in the training split (list order is arbitrary)
train_user_id_all = [*set(train['user_id'])]
train_anime_id_all = [*set(train['anime_id'])]

In [None]:
#Matrix dimensions: one row per possible ID (0..max ID), r latent factors per row
r = 10
p = 1 + max(train_user_id_all)
q = 1 + max(train_anime_id_all)
print(p,q)

73517 34476


Because the matrix dimensions are taken from the maximum ID, some rows of U and V do not correspond to any user/anime in the training set. Those rows receive no gradient from the loss and are therefore not updated during training.

In [None]:
# Initializing U and V, values are all non-negative. Mean and Scale determined by distribution of actual ratings and initial ratings
# A seeded generator makes the starting point reproducible across kernel restarts.
init_rng = np.random.default_rng(0)
U_init = jnp.abs(init_rng.normal(1,0.1,size = (p,r)))
V_init = jnp.abs(init_rng.normal(1,0.1,size = (q,r)))

In [None]:
#To be used as a divisor for regularization later, denotes the number of parameters in U and V
# Only rows whose ID occurs in the training split are counted; each such row has r latent factors.
# Using r (instead of a literal 10) keeps the divisor correct if the rank is changed.
num_elements_U = len(train_user_id_all) * r
num_elements_V = len(train_anime_id_all) * r

In [None]:
#Loss function: mean squared error plus norm penalties on the trained rows of U and V
reg_constant =  0.0008
@jax.jit
def loss(params,data):
  """Return the regularized training objective.

  params: sequence whose first two entries are the factor matrices U and V.
  data: sequence whose first three entries are user indices, anime indices and ratings.
  """
  user_factors = params[0]
  anime_factors = params[1]
  user_idx = data[0]
  anime_idx = data[1]
  observed = data[2]
  # Predicted rating = row-wise dot product of the matching user and anime factor rows
  predicted = jnp.sum(user_factors[user_idx]*anime_factors[anime_idx], axis = 1)
  residual = predicted-observed
  mse = jnp.mean(residual**2)
  # Penalties only cover rows whose IDs occur in the training split.
  # NOTE(review): jnp.linalg.norm is the un-squared norm — confirm this is the intended "L2" penalty.
  penalty_U = (reg_constant/num_elements_U) * (jnp.linalg.norm(user_factors[train_user_id_all,]))
  penalty_V = (reg_constant/num_elements_V) * (jnp.linalg.norm(anime_factors[train_anime_id_all,]))
  return mse + penalty_U + penalty_V

In [None]:
#Gradient Functions
# argnums indexes the positional arguments of loss(params, data):
#   argnums = 0 differentiates with respect to `params` (the whole U/V container),
#   argnums = 1 differentiates with respect to `data`.
# NOTE(review): despite the names, neither is a gradient w.r.t. U or V alone, and `data`
# holds index/rating arrays rather than model parameters — confirm V_grad is what was intended.
# The optimisation further down passes `loss` directly to ProjectedGradient rather than using these.
U_grad = jax.jit(jax.grad(loss, argnums = 0))
V_grad = jax.jit(jax.grad(loss, argnums = 1))

In [None]:
#Validation columns as plain NumPy arrays: user ids, anime ids and ratings
valid_user_id = np.array(val_np[:, 0])
valid_anime_id = np.array(val_np[:, 1])
valid_rating = np.array(val_np[:, 2])

In [None]:
#Sets of the user and anime IDs seen in training, used for membership tests
user_set = {*train['user_id']}
anime_set = {*train['anime_id']}

In [None]:
#returns 4 index arrays (categories), each containing the row indices of x that belong in the specified category
def split_category(x, user_set, anime_set):
  """Partition the rows of x by whether their user / anime IDs are known.

  Args:
    x: 2-D array whose column 0 holds user IDs and column 1 holds anime IDs.
    user_set: collection of known user IDs.
    anime_set: collection of known anime IDs.

  Returns:
    Tuple (in_both, in_user, in_anime, in_none) of integer index arrays:
    rows with both IDs known, only the user known, only the anime known,
    and neither known. Every row index appears in exactly one array.
  """
  user_id, anime_id = np.asarray(x[:,0]), np.asarray(x[:,1])
  # np.isin needs an array-like rather than a set; one vectorized membership test
  # per column replaces four Python-level passes over every row.
  user_known = np.isin(user_id, np.array(list(user_set)))
  anime_known = np.isin(anime_id, np.array(list(anime_set)))
  # flatnonzero always yields integer indices, so an empty category can still be used for indexing.
  in_both = np.flatnonzero(user_known & anime_known)
  in_user = np.flatnonzero(user_known & ~anime_known)
  in_anime = np.flatnonzero(~user_known & anime_known)
  in_none = np.flatnonzero(~user_known & ~anime_known)
  return in_both, in_user, in_anime, in_none

In [None]:
#Obtaining the indices attached to each category for validation set
# Each returned array holds row positions into val_np.
in_both, in_user, in_anime, in_none = split_category(val_np, user_set, anime_set)

In [None]:
#Check that the four category sizes add up to the size of the validation set (the difference should be 0)
len(in_both) + len(in_user) + len(in_anime) + len(in_none) - len(val_np)

0

In [None]:
#in_none is empty
# i.e. for this split no validation row has both an unseen user and an unseen anime.
len(in_none)

0

In [None]:
# Install jaxopt library to perform Projected Gradient Descent
!pip install jaxopt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jaxopt
  Downloading jaxopt-0.6-py3-none-any.whl (142 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m142.2/142.2 KB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxopt
Successfully installed jaxopt-0.6


In [None]:
# Import optimisation packages
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_non_negative

In [None]:
# Projected Gradient Descent for the constrained problem: U and V must stay non-negative
initial_factors = [U_init,V_init]
training_triplets = (train_user_id, train_anime_id, train_actual_ratings)
pg = ProjectedGradient(fun=loss, projection=projection_non_negative, maxiter = 350)
pg_sol = pg.run(init_params = initial_factors, data = training_triplets)

In [None]:
#Extract the fitted factor matrices from the solver result
fitted_params = pg_sol.params
U = fitted_params[0]
V = fitted_params[1]

In [None]:
#Checking Training Loss
# Evaluates the same objective that was minimized (MSE plus the regularization terms) on the training data.
loss(pg_sol.params, (train_user_id, train_anime_id, train_actual_ratings))

Array(1.0345982, dtype=float32)

In [None]:
#Sanity check: number of negative entries in V (should be 0)
jnp.sum(V < 0)

Array(0, dtype=int32)

In [None]:
#Sanity check: number of negative entries in U (should be 0)
jnp.sum(U < 0)

Array(0, dtype=int32)

# Evaluation on Validation Set

In [None]:
#Average factor vector over the rows of U / V whose IDs occur in the training set
mean_user_factor = U[train_user_id_all,].mean(axis = 0)
mean_anime_factor = V[train_anime_id_all,].mean(axis = 0)

In [None]:
#Mean of the observed training ratings (column 2 of train_np)
global_average_rating = train_np[:,2].mean()
global_average_rating

Array(7.808637, dtype=float32)

In [None]:
# Sanity check that dot product of mean_user_factor and mean_anime_factor is close to global average
# (the elementwise product summed over the factor axis is that dot product)
jnp.sum(mean_user_factor*mean_anime_factor)

Array(7.245419, dtype=float32)

In [None]:
#Observed rating range in the training split; predictions are later clipped to this range
# Vectorized reductions avoid iterating the Series element by element; int() keeps the displayed tuple as plain Python ints.
int(train['rating'].min()), int(train['rating'].max())

(1, 10)

In [None]:
#in_none is optional; it is left out by callers when it is empty
# Ratings below 1 and above 10 are set as 1 and 10 respectively
def nmf_mse(x, in_both, in_user, in_anime, in_none=None):
  """Mean squared error of the clipped NMF predictions over the rows of x.

  Reads the module-level fitted factors U and V, and uses mean_user_factor /
  mean_anime_factor in place of the factor row of an ID not seen in training.

  Args:
    x: 2-D array with user IDs in column 0, anime IDs in column 1 and ratings in column 2.
    in_both: row indices of x whose user and anime are both known.
    in_user: row indices whose user is known but whose anime is not.
    in_anime: row indices whose anime is known but whose user is not.
    in_none: optional row indices where neither ID is known; these rows are
      predicted from the two mean factor vectors. When omitted, such rows add
      no error (the previous behaviour), which is only correct if there are none.

  Returns:
    Sum of squared errors over the supplied rows divided by len(x).
  """
  user_id, anime_id, rating = np.array(x[:,0]), np.array(x[:,1]), np.array(x[:,2])

  def _squared_error(rows, user_vecs, anime_vecs):
    # Row-wise dot product, clipped to the valid rating range before scoring
    pred = jnp.clip(jnp.sum(user_vecs*anime_vecs, axis = 1), 1, 10)
    return jnp.sum((pred-rating[rows])**2)

  squared_error = 0
  #Pred when both user and anime are known
  if len(in_both)>0:
    squared_error += _squared_error(in_both, U[user_id[in_both]], V[anime_id[in_both]])
  #Pred in User only
  if len(in_user)>0:
    squared_error += _squared_error(in_user, U[user_id[in_user]], mean_anime_factor)
  #Pred in Anime only
  if len(in_anime)>0:
    squared_error += _squared_error(in_anime, mean_user_factor, V[anime_id[in_anime]])
  #Pred when neither is known: one shared prediction from the two mean factor vectors
  if in_none is not None and len(in_none)>0:
    pred_none = jnp.clip(jnp.sum(mean_user_factor*mean_anime_factor), 1, 10)
    squared_error += jnp.sum((pred_none-rating[in_none])**2)
  return squared_error/len(x)

In [None]:
#Compute Validation MSE for NMF Model
# in_none is not involved here because it was checked to be empty for this split.
nmf_mse(val_np,in_both, in_user, in_anime)

Array(1.4134694, dtype=float32)