> Marcel Fernández
> Email: marcelfernandez.serrano@ru.nl
> Affiliation: Radboud University
> Location: Nijmegen, Netherlands

> Alejandro Pastor Rubio
> Email: alejandropastor.rubio@ru.nl
> Affiliation: Radboud University
> Location: Nijmegen, Netherlands

> Xuezheng Zhang
> Email: xuezheng.zhang@ru.nl
> Affiliation: Radboud University
> Location: Nijmegen, Netherlands


# Setting Up Environment

In [None]:
import flax
import jax
import optax
import pandas as pd
import numpy as np
import gc
import pickle
import json
import math

from jax import numpy as jnp
from flax.training import train_state
from flax import linen

from google.colab import drive

In [None]:
# Mount Google Drive so the shared project folder is reachable under /content/drive.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Work from the shared project folder: data/*.csv and every file saved below are relative to it.
%cd drive/Shareddrives/InfoRetrievalProject/Project/

/content/drive/Shareddrives/InfoRetrievalProject/Project


# Data Loader

## Loading the ratings data

In [None]:
# Load user ratings and drop rows rated -1
# (presumably "watched but not rated" -- confirm against the dataset description).
ratings_df = pd.read_csv("data/rating.csv").query("rating != -1")
print("Data : {}".format(ratings_df.shape))

# Order rows by user and make sure anime ids are plain integers.
ratings_df = (
    ratings_df
    .sort_values(by='user_id')
    .reset_index(drop=True)
    .astype({'anime_id': int})
)

ratings_df.head(5)

Data : (6337241, 3)


Unnamed: 0,user_id,anime_id,rating
0,1,8074,10
1,1,11617,10
2,1,11757,10
3,1,15451,10
4,2,11771,10


## Loading the anime data

In [None]:
# Anime metadata (name, genre, type, episodes, average rating, members), ordered by anime id.
anime_df = pd.read_csv("data/anime.csv")
print(f"Anime : {anime_df.shape}")
anime_df = anime_df.sort_values(by='anime_id').reset_index(drop=True)

anime_df.head()

Anime : (12294, 7)


Unnamed: 0,anime_id,name,genre,type,episodes,rating,members
0,1,Cowboy Bebop,"Action, Adventure, Comedy, Drama, Sci-Fi, Space",TV,26,8.82,486824
1,5,Cowboy Bebop: Tengoku no Tobira,"Action, Drama, Mystery, Sci-Fi, Space",Movie,1,8.4,137636
2,6,Trigun,"Action, Comedy, Sci-Fi",TV,26,8.32,283069
3,7,Witch Hunter Robin,"Action, Drama, Magic, Mystery, Police, Superna...",TV,26,7.36,64905
4,8,Beet the Vandel Buster,"Adventure, Fantasy, Shounen, Supernatural",TV,52,7.06,9848


## Merge Ratings and Anime Data


In [None]:
# Attach anime metadata to every rating (default inner join: ratings whose anime_id
# is missing from anime.csv are dropped). Both frames have a "rating" column, which
# pandas suffixes as rating_x (left: the user's rating) and rating_y (right: the
# anime's average), so rename them to something readable.
ratings_df = (
    ratings_df
    .merge(anime_df, on="anime_id")
    .rename(columns={"rating_x": "user_rating", "rating_y": "average_anime_rating"})
)

ratings_df.head()


Unnamed: 0,user_id,anime_id,user_rating,name,genre,type,episodes,average_anime_rating,members
0,1,8074,10,Highschool of the Dead,"Action, Ecchi, Horror, Supernatural",TV,12,7.46,535892
1,1,11617,10,High School DxD,"Comedy, Demons, Ecchi, Harem, Romance, School",TV,12,7.7,398660
2,1,11757,10,Sword Art Online,"Action, Adventure, Fantasy, Game, Romance",TV,25,7.83,893100
3,1,15451,10,High School DxD New,"Action, Comedy, Demons, Ecchi, Harem, Romance,...",TV,12,7.87,266657
4,2,11771,10,Kuroko no Basket,"Comedy, School, Shounen, Sports",TV,25,8.46,338315


## A Few Sanity Checks


In [None]:
# Each check asks whether ANY value in the column is missing. (np.alltrue would only
# report True if EVERY value were missing, and it is removed in NumPy 2.0.)
print("Are there Null User IDs ? : {}".format(ratings_df["user_id"].isnull().any()))
print("Are there Null Anime IDs ?: {}".format(ratings_df["anime_id"].isnull().any()))
print("Are there Null Ratings ?  : {}".format(ratings_df["user_rating"].isnull().any()))

Are there Null User IDs ? : False
Are there Null ISBN ?     : False
Are there Null Ratings ?  : False


## Separate Users With Only One Review and Remove Them From Data


In [None]:
from collections import Counter

# Count reviews per user in one vectorised pass, then split user ids by whether
# they have exactly one review or more than one.
review_counts = ratings_df["user_id"].value_counts()
less_review_users = review_counts[review_counts == 1].index.tolist()
more_review_users = review_counts[review_counts > 1].index.tolist()

print("Users with only single review : {}".format(len(less_review_users)))
print("Users with multiple reviews   : {}".format(len(more_review_users)))

# Keep only users with several reviews. .copy() makes the result an independent
# frame, so the column assignments in the next cell don't act on a view of the
# filtered data (pandas SettingWithCopyWarning).
ratings_df = ratings_df[ratings_df["user_id"].isin(more_review_users)].copy()

Users with only single review : 3249
Users with multiple reviews   : 66351


## Create Mapping Dictionaries and Shuffle the Data


In [None]:
# Lookups keyed by the original (real) anime ids.
unique_items = ratings_df["anime_id"].unique()
item_to_idx = {item: position for position, item in enumerate(unique_items)}
item_to_title = dict(zip(ratings_df["anime_id"].values, ratings_df["name"].values))

# Re-index users to contiguous 0..N-1 ids (in order of first appearance) and keep
# both directions so model outputs can be translated back to real ids.
valores_unicos_user_id = list(ratings_df['user_id'].unique())
real_to_mapped_user_id = {real: mapped for mapped, real in enumerate(valores_unicos_user_id)}
mapped_to_real_user_id = dict(enumerate(valores_unicos_user_id))

# Same re-indexing for anime ids.
valores_unicos_anime_id = list(ratings_df['anime_id'].unique())
real_to_mapped_anime_id = {real: mapped for mapped, real in enumerate(valores_unicos_anime_id)}
mapped_to_real_anime_id = dict(enumerate(valores_unicos_anime_id))

ratings_df['user_id'] = ratings_df['user_id'].map(real_to_mapped_user_id)
ratings_df['anime_id'] = ratings_df['anime_id'].map(real_to_mapped_anime_id)
ratings_df = ratings_df.dropna(subset=['anime_id'])
ratings_df['anime_id'] = ratings_df['anime_id'].astype(int)

# NOTE: anime that never appear in ratings_df have no entry in the mapping, so
# their anime_id becomes NaN in anime_df.
anime_df['anime_id'] = anime_df['anime_id'].map(real_to_mapped_anime_id)

# Reproducible shuffle of all rating rows.
ratings_df = ratings_df.sample(frac=1.0, random_state=123)
ratings_df.head()


Unnamed: 0,user_id,anime_id,user_rating,name,genre,type,episodes,average_anime_rating,members
6299446,66004,1017,10,Redline,"Action, Cars, Sci-Fi, Sports",Movie,1,8.33,109392
5225198,53838,1624,6,Kimi no Iru Machi: Tasogare Kousaten,"Drama, Romance, School, Shounen",OVA,2,7.33,25799
1688870,18160,946,8,Rainbow: Nisha Rokubou no Shichinin,"Drama, Historical, Seinen, Thriller",TV,26,8.64,139474
5266019,54182,3608,7,Yutori-chan,"Comedy, Slice of Life",ONA,25,6.25,3516
6159692,64266,906,5,Hitsugi no Chaika: Avenging Battle,"Action, Adventure, Comedy, Fantasy, Romance, S...",TV,10,7.33,91049


In [None]:
## CREATE VARIABLES
# Table sizes for the embedding model. n_users counts every user still present,
# including the 500 users that are held out as "unseen" in the next section.
n_users = len(ratings_df['user_id'].unique())
n_items = len(ratings_df['anime_id'].unique())
ratings = ratings_df["user_rating"].unique()

print(f"Unique Items : {n_items}")
print(f"Unique Users : {n_users}")
print(f"Ratings : {ratings}")

Unique Items : 9926
Unique Users : 66351
Ratings : [10  6  8  7  5  9  3  4  2  1]


## Hold Out the First 500 Users (So Some Users Remain Unseen After Training)


In [None]:
# Hold out 500 users in a separate dataset so the model can later be evaluated on
# users it never saw. ratings_df was shuffled above, so unique() (order of first
# appearance) yields these users in shuffled order rather than by id.
users_id_list = list(ratings_df["user_id"].unique())
first_500_users = users_id_list[:500]

# Split into the held-out frame and the frame used for training/testing.
is_unseen_user = ratings_df["user_id"].isin(first_500_users)
ratings_df_unseen = ratings_df[is_unseen_user]
ratings_df = ratings_df[~is_unseen_user]

# Show the resulting shapes.
print("Primeros 500 user_id:")
print(ratings_df_unseen.shape)

print("\nRestantes user_id:")
print(ratings_df.shape)


Primeros 500 user_id:
(147928, 9)

Restantes user_id:
(6186062, 9)


# Training The Model

## Train & Test Split

In [None]:
def train_test_split(df, train_fraction=0.9, log_every=10000):
    """
    Splits a DataFrame into training and testing datasets for each unique user.

    Users are visited in ascending user_id order. For a user with n rows (kept in
    the order they appear in ``df``), the first ``int(n * train_fraction)`` rows go
    to the training set and the rest to the test set. A user with a single row is
    placed entirely in the training set.

    Parameters:
        df (DataFrame): Input DataFrame containing at least the columns "user_id", "anime_id", and "user_rating".
        train_fraction (float, optional): Fraction of each user's rows used for training. Defaults to 0.9.
        log_every (int, optional): Print a progress line every ``log_every`` users; 0/None disables it.
            Defaults to 10000.

    Returns:
        tuple: Four NumPy arrays:
            - X_train: Features for the training set (user_id, anime_id).
            - X_test: Features for the testing set (user_id, anime_id).
            - Y_train: Labels for the training set (user_rating).
            - Y_test: Labels for the testing set (user_rating).
    """
    X_train, X_test, Y_train, Y_test = [], [], [], []

    # One groupby pass instead of filtering the whole frame once per user (which is
    # O(users * rows)). groupby(sort=True) visits users in ascending order, like
    # np.unique, and preserves each user's row order.
    for i, (_, user_rows) in enumerate(df.groupby("user_id", sort=True)):
        features = user_rows[["user_id", "anime_id"]].values.tolist()
        labels = user_rows["user_rating"].values.tolist()

        if len(labels) == 1:
            split = 1  # a lone sample always goes to the train set
        else:
            split = int(len(labels) * train_fraction)

        X_train.extend(features[:split])
        Y_train.extend(labels[:split])
        X_test.extend(features[split:])
        Y_test.extend(labels[split:])

        if log_every and (i + 1) % log_every == 0:
            print("{} users completed.".format(i + 1))

    return np.array(X_train), np.array(X_test), np.array(Y_train), np.array(Y_test)

# Per-user train/test split (90/10 by default) of the non-held-out ratings; %time reports its cost.
%time X_train, X_test, Y_train, Y_test = train_test_split(ratings_df)

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape

10000 users completed.
20000 users completed.
30000 users completed.
40000 users completed.
50000 users completed.
60000 users completed.
CPU times: user 18min 4s, sys: 6.8 s, total: 18min 10s
Wall time: 13min 34s


((5536847, 2), (649215, 2), (5536847,), (649215,))

### Sanity Checks

In [None]:
# Every remaining user has at least two ratings, so each should land in both splits.
train_users = set(np.unique(X_train[:,0]))
test_users = set(np.unique(X_test[:,0]))

train_not_test = train_users.difference(test_users)
test_not_train = test_users.difference(train_users)

print("Users in Train but not Test : {}".format(len(train_not_test)))
print("Users in Test  but not Train : {}".format(len(test_not_train)))

print("Nan values in X_train: ", np.isnan(X_train).sum())
print("Nan values in Y_train: ",np.isnan(Y_train).sum())

print("Nan values in X_test: ", np.isnan(X_test).sum())
print("Nan values in Y_test: ",np.isnan(Y_test).sum())

# Ids index the embedding tables, so they must lie in [0, table size).
print("Check that the max value of user id is below the length of the unique array")
print("Max train:", np.max(X_train[:,0]), "Max test:", np.max(X_test[:,0]), "Length array:",n_users)

print("Check that the max value of anime id is below the length of the unique array")
print("Max train:", np.max(X_train[:,1]), "Max test:", np.max(X_test[:,1]), "Length array:",n_items)

# Fail loudly instead of only printing: JAX gathers do not raise on out-of-range
# indices, so a bad id would otherwise go unnoticed during training.
assert not train_not_test and not test_not_train, "some users appear in only one split"
assert np.isnan(X_train).sum() == 0 and np.isnan(Y_train).sum() == 0, "NaN in training data"
assert np.isnan(X_test).sum() == 0 and np.isnan(Y_test).sum() == 0, "NaN in test data"
assert min(X_train.min(), X_test.min()) >= 0, "negative user/anime id"
assert max(np.max(X_train[:,0]), np.max(X_test[:,0])) < n_users, "user id outside embedding table"
assert max(np.max(X_train[:,1]), np.max(X_test[:,1])) < n_items, "anime id outside embedding table"

Users in Train but not Test : 0
Users in Test  but not Train : 0
Nan values in X_train:  0
Nan values in Y_train:  0
Nan values in X_test:  0
Nan values in Y_test:  0
Check that the max value of user id the length of the unique array
Max train: 66350 Max test: 66350 Lenght array: 66351
Check that the max value of anime id is the length of the unique array
Max train: 9925 Max test: 9918 Length array: 9926


## Create Model using Embeddings

In [None]:
from flax import linen
from flax.training import train_state

n_factors = 32

class SimpleRecSystem(linen.Module):
    """
    A simple recommendation system using user and item embeddings.

    The score for a (user, item) pair is ReLU(<user_embedding, item_embedding>).

    Attributes:
        n_users (int): Number of rows in the user embedding table. Defaults to the notebook-level ``n_users``.
        n_items (int): Number of rows in the item embedding table. Defaults to the notebook-level ``n_items``.
        n_factors (int): Embedding dimension (latent factors). Defaults to the notebook-level ``n_factors``.
    """

    # Annotated so Flax registers them as dataclass fields: SimpleRecSystem() keeps
    # the notebook-level defaults, while e.g. SimpleRecSystem(n_users=...) can
    # override them without touching globals.
    n_users: int = n_users
    n_items: int = n_items
    n_factors: int = n_factors

    def setup(self):
        """
        Initializes the embeddings for users and items.

        The submodule names ("User Embeddings", "Item Embeddings") are the keys in the
        parameter tree, so saved parameters depend on them.
        """
        self.user_embeddings = linen.Embed(self.n_users, self.n_factors, name="User Embeddings")
        self.item_embeddings = linen.Embed(self.n_items, self.n_factors, name="Item Embeddings")

    def __call__(self, X_batch):
        """
        Computes the interaction scores between users and items in the batch.

        Parameters:
            X_batch (ndarray): Integer array whose column 0 holds user indices and
                column 1 holds item indices.

        Returns:
            ndarray: One score per row: the dot product of the user and item
                embeddings passed through ReLU.
        """
        users = self.user_embeddings(X_batch[:,0])
        items = self.item_embeddings(X_batch[:,1])

        # ReLU clamps negative dot products to 0 (with zero gradient there), so
        # predictions are non-negative.
        return linen.relu((users * items).sum(axis=1))

    def get_item_embeddings(self, item_ids):
        """
        Fetch embeddings for specific item IDs.

        Parameters:
            item_ids (ndarray): Array of (mapped) item indices.

        Returns:
            ndarray: Item embeddings.
        """
        return self.item_embeddings(item_ids)

    def get_user_embeddings(self, user_ids):
        """
        Fetch embeddings for specific user IDs.

        Parameters:
            user_ids (ndarray): Array of (mapped) user indices.

        Returns:
            ndarray: User embeddings.
        """
        return self.user_embeddings(user_ids)

In [None]:
from jax import numpy as jnp

seed = jax.random.PRNGKey(0)

rec_system = SimpleRecSystem()

# A small dummy batch of (user, item) index pairs is enough for Flax to infer
# the parameter shapes.
dummy_batch = jax.random.randint(seed, (100, 2), minval=1, maxval=20)
params = rec_system.init(seed, dummy_batch)

# Show each embedding table and its (rows, n_factors) shape.
for layer_name, layer in params["params"].items():
    print(f"Layer Name : {layer_name}")
    print(f"\tLayer Weights : {layer['embedding'].shape}")

Layer Name : User Embeddings
	Layer Weights : (66351, 32)
Layer Name : Item Embeddings
	Layer Weights : (9926, 32)


## Loss Function

In [None]:
@jax.jit
def MSELoss(s, y):
    """
    Mean Squared Error between predictions and targets.

    Parameters:
        s (ndarray): Model predictions.
        y (ndarray): Target values (broadcast against ``s``).

    Returns:
        float: Mean of the element-wise squared differences.
    """
    residual = s - y
    return jnp.mean(residual ** 2)

def MAELoss(params, input_data, actual):
    """
    Mean Absolute Error of the model's predictions on ``input_data``.

    Unlike MSELoss, this runs the model itself (through the notebook-level
    ``rec_system``). Its signature is therefore (params, input_data, actual) and
    it does not fit train_step's ``loss_fn(scores, Y)`` calling convention.

    Parameters:
        params (dict): Model parameters.
        input_data (ndarray): Input data to the model.
        actual (ndarray): Actual target values.

    Returns:
        float: Mean Absolute Error.
    """
    predictions = rec_system.apply(params, input_data).squeeze()
    targets = actual.squeeze()
    return jnp.abs(targets - predictions).mean()

## Training the Model


In [None]:
from jax import value_and_grad
from functools import partial

@partial(jax.jit, static_argnames=['loss_fn', 'validation'])
def train_step(state, X, Y, loss_fn, validation):
    """
    Performs a single training step, or only evaluates the loss in validation mode.

    Parameters:
        state (TrainState): Current training state, containing model parameters and optimizer state.
        X (ndarray): Input features for the model.
        Y (ndarray): Target values.
        loss_fn (function): Loss called as ``loss_fn(scores, Y)`` (e.g. MSELoss). Static, so it must be hashable.
        validation (bool): If True, no update is made and no gradients are computed.
            Static: pass a Python bool.

    Returns:
        tuple:
            - Updated state (or the same state if in validation mode).
            - Computed loss value.
    """
    def calculate_loss(params):
        """Loss of the model with ``params`` on this batch."""
        scores = state.apply_fn(params, X)
        return loss_fn(scores, Y)

    # `validation` is static, so this is an ordinary Python branch resolved at trace
    # time. Validation compiles a forward-only program instead of back-propagating
    # through the whole validation set and then discarding the gradients.
    if validation:
        return state, calculate_loss(state.params)

    loss, grads = value_and_grad(calculate_loss)(state.params)
    return state.apply_gradients(grads=grads), loss

In [None]:
from jax import value_and_grad
from tqdm import tqdm

def TrainModel(X, Y, X_val, Y_val, epochs, loss_fn, params, model, optimizer, batch_size=256, shuffle_seed=None):
    """
    Trains the model over a specified number of epochs, recording training and validation losses.

    Parameters:
        X (ndarray): Training input data.
        Y (ndarray): Training target values.
        X_val (ndarray): Validation input data.
        Y_val (ndarray): Validation target values.
        epochs (int): Number of training epochs.
        loss_fn (function): Loss function to compute training and validation losses.
        params (dict): Initial model parameters.
        model (linen.Module): The model to be trained.
        optimizer (optax.GradientTransformation): Optimizer for updating model parameters.
        batch_size (int, optional): Batch size for training. Defaults to 256.
        shuffle_seed (int, optional): If given, training rows are re-shuffled at the start
            of every epoch with a NumPy generator seeded by ``shuffle_seed + epoch``.
            Defaults to None, which visits rows in the given order.

    Returns:
        tuple:
            - model_state (TrainState): Final state of the trained model.
            - loss_list (list): Mean training batch loss for each epoch.
            - validation_loss (list): List of validation losses for each epoch.
    """
    model_state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
    loss_list = []
    validation_loss = []
    n_samples = X.shape[0]
    # Label printed losses with the loss function actually used (e.g. "MSELoss").
    loss_name = getattr(loss_fn, "__name__", "Loss")

    for i in range(1, epochs + 1):
        if shuffle_seed is None:
            X_epoch, Y_epoch = X, Y
        else:
            order = np.random.default_rng(shuffle_seed + i).permutation(n_samples)
            X_epoch, Y_epoch = X[order], Y[order]

        # range(0, n, batch_size) gives ceil(n / batch_size) batch starts, and slicing
        # past the end truncates the final batch. Unlike n // batch_size + 1 batches,
        # this never yields an empty trailing batch (whose loss would be NaN) when n
        # is a multiple of batch_size.
        losses = []  ## Record loss of each batch
        for start in tqdm(range(0, n_samples, batch_size)):
            X_batch = X_epoch[start:start + batch_size]
            Y_batch = Y_epoch[start:start + batch_size]

            model_state, loss = train_step(model_state, X_batch, Y_batch, loss_fn, False)
            losses.append(loss)

        loss_value = jnp.array(losses).mean()
        loss_list.append(loss_value)
        print("Train {} : {:.3f}".format(loss_name, loss_value))

        # Validation on the held-out TEST split (no parameter update)
        _, val_loss = train_step(model_state, X_val, Y_val, loss_fn, True)
        validation_loss.append(val_loss)
        print("Valid {} : {:.3f}".format(loss_name, val_loss))
        gc.collect()
    return model_state, loss_list, validation_loss

In [None]:
# Hyperparameters
epochs = 35
batch_size = 10000
learning_rate = 0.01
seed = jax.random.PRNGKey(0)

# Fresh model and parameters; a single dummy (user, item) pair suffices for shape inference.
rec_system = SimpleRecSystem()
params = rec_system.init(seed, jax.random.randint(seed, (1, 2), minval=1, maxval=20))

optimizer = optax.nadam(learning_rate=learning_rate)


In [None]:
final_state, train_loss, validation_loss = TrainModel(
    X=X_train,
    Y=Y_train,
    X_val=X_test,  # the test split doubles as the validation set
    Y_val=Y_test,
    epochs=epochs,
    loss_fn=MSELoss,
    params=params,
    model=rec_system,
    optimizer=optimizer,
    batch_size=batch_size,
)

100%|██████████| 554/554 [00:18<00:00, 30.41it/s]


Train MSELoss : 58.871
Valid MSELoss : 51.113


100%|██████████| 554/554 [00:10<00:00, 51.83it/s]


Train MSELoss : 37.295
Valid MSELoss : 34.159


100%|██████████| 554/554 [00:10<00:00, 54.99it/s]


Train MSELoss : 11.651
Valid MSELoss : 11.847


100%|██████████| 554/554 [00:11<00:00, 48.52it/s]


Train MSELoss : 9.898
Valid MSELoss : 9.910


100%|██████████| 554/554 [00:10<00:00, 50.90it/s]


Train MSELoss : 4.337
Valid MSELoss : 7.068


100%|██████████| 554/554 [00:11<00:00, 49.02it/s]


Train MSELoss : 3.989
Valid MSELoss : 6.233


100%|██████████| 554/554 [00:10<00:00, 53.27it/s]


Train MSELoss : 3.123
Valid MSELoss : 6.261


100%|██████████| 554/554 [00:11<00:00, 46.96it/s]


Train MSELoss : 3.161
Valid MSELoss : 6.167


100%|██████████| 554/554 [00:11<00:00, 49.42it/s]


Train MSELoss : 2.816
Valid MSELoss : 5.226


100%|██████████| 554/554 [00:11<00:00, 46.75it/s]


Train MSELoss : 3.106
Valid MSELoss : 5.719


100%|██████████| 554/554 [00:10<00:00, 55.00it/s]


Train MSELoss : 2.921
Valid MSELoss : 5.003


100%|██████████| 554/554 [00:11<00:00, 48.40it/s]


Train MSELoss : 2.738
Valid MSELoss : 5.257


100%|██████████| 554/554 [00:10<00:00, 54.37it/s]


Train MSELoss : 2.768
Valid MSELoss : 4.654


100%|██████████| 554/554 [00:11<00:00, 46.64it/s]


Train MSELoss : 2.474
Valid MSELoss : 5.249


100%|██████████| 554/554 [00:14<00:00, 37.13it/s]


Train MSELoss : 2.713
Valid MSELoss : 4.986


100%|██████████| 554/554 [00:11<00:00, 49.61it/s]


Train MSELoss : 2.612
Valid MSELoss : 4.226


100%|██████████| 554/554 [00:12<00:00, 45.25it/s]


Train MSELoss : 2.742
Valid MSELoss : 4.550


100%|██████████| 554/554 [00:10<00:00, 54.31it/s]


Train MSELoss : 2.532
Valid MSELoss : 4.108


100%|██████████| 554/554 [00:11<00:00, 47.19it/s]


Train MSELoss : 2.667
Valid MSELoss : 5.582


100%|██████████| 554/554 [00:11<00:00, 49.66it/s]


Train MSELoss : 2.681
Valid MSELoss : 3.752


100%|██████████| 554/554 [00:11<00:00, 46.27it/s]


Train MSELoss : 2.584
Valid MSELoss : 4.012


100%|██████████| 554/554 [00:10<00:00, 52.66it/s]


Train MSELoss : 2.423
Valid MSELoss : 3.607


100%|██████████| 554/554 [00:11<00:00, 46.55it/s]


Train MSELoss : 2.332
Valid MSELoss : 4.432


100%|██████████| 554/554 [00:10<00:00, 51.15it/s]


Train MSELoss : 2.370
Valid MSELoss : 3.490


100%|██████████| 554/554 [00:11<00:00, 46.34it/s]


Train MSELoss : 2.574
Valid MSELoss : 3.631


100%|██████████| 554/554 [00:10<00:00, 50.40it/s]


Train MSELoss : 2.520
Valid MSELoss : 3.508


100%|██████████| 554/554 [00:11<00:00, 50.13it/s]


Train MSELoss : 2.387
Valid MSELoss : 3.286


100%|██████████| 554/554 [00:11<00:00, 49.73it/s]


Train MSELoss : 2.276
Valid MSELoss : 3.499


100%|██████████| 554/554 [00:10<00:00, 54.81it/s]


Train MSELoss : 2.596
Valid MSELoss : 4.639


100%|██████████| 554/554 [00:11<00:00, 47.91it/s]


Train MSELoss : 2.883
Valid MSELoss : 3.797


100%|██████████| 554/554 [00:10<00:00, 52.10it/s]


Train MSELoss : 2.805
Valid MSELoss : 3.863


100%|██████████| 554/554 [00:10<00:00, 50.54it/s]


Train MSELoss : 2.263
Valid MSELoss : 3.009


100%|██████████| 554/554 [00:09<00:00, 57.45it/s]


Train MSELoss : 2.007
Valid MSELoss : 3.109


100%|██████████| 554/554 [00:10<00:00, 54.91it/s]


Train MSELoss : 2.184
Valid MSELoss : 3.159


100%|██████████| 554/554 [00:11<00:00, 46.40it/s]


Train MSELoss : 2.246
Valid MSELoss : 5.559


# Save the Model Data

In [None]:
# The real ids come from pandas .unique() as NumPy integers, which json cannot
# serialise, so cast them to plain Python ints before storing the mappings for the
# testing phase. (json.dump still writes the int keys as strings.)
converted_real_to_mapped_user_id = dict(zip(map(int, real_to_mapped_user_id.keys()), real_to_mapped_user_id.values()))
converted_mapped_to_real_user_id = dict(zip(mapped_to_real_user_id.keys(), map(int, mapped_to_real_user_id.values())))
converted_real_to_mapped_anime_id = dict(zip(map(int, real_to_mapped_anime_id.keys()), real_to_mapped_anime_id.values()))
converted_mapped_to_real_anime_id = dict(zip(mapped_to_real_anime_id.keys(), map(int, mapped_to_real_anime_id.values())))


In [None]:
# Storing dictionaries as json
import json

# Write each id mapping to its own pretty-printed JSON file, in a fixed order.
json_outputs = [
    ('real_to_mapped_user.json', converted_real_to_mapped_user_id),
    ('mapped_to_real_user.json', converted_mapped_to_real_user_id),
    ('real_to_mapped_anime.json', converted_real_to_mapped_anime_id),
    ('mapped_to_real_anime.json', converted_mapped_to_real_anime_id),
]
for json_path, mapping in json_outputs:
    with open(json_path, 'w') as handle:
        json.dump(mapping, handle, indent=4)

In [None]:
# Persist both datasets: ratings_df (used for training) and the held-out users (used for testing).
for frame, csv_path in ((ratings_df, "ratings_df.csv"), (ratings_df_unseen, "ratings_df_unseen.csv")):
    frame.to_csv(csv_path, index=False)

In [None]:
# Wrap each split array in a DataFrame and write it to its own CSV.
X_train_df, Y_train_df, X_test_df, Y_test_df = map(pd.DataFrame, (X_train, Y_train, X_test, Y_test))

split_files = (
    (X_train_df, "X_train.csv"),
    (Y_train_df, "Y_train.csv"),
    (X_test_df, "X_test.csv"),
    (Y_test_df, "Y_test.csv"),
)
for frame, csv_path in split_files:
    frame.to_csv(csv_path, index=False)

print("Data has been successfully saved as CSV files.")


Data has been successfully saved as CSV files.


In [None]:
# Bundle the trained parameters with the per-epoch train/test losses and pickle them.
# NOTE: only unpickle this file from a trusted source (pickle can run code on load).
model_outputs = {
    "params": final_state.params,
    "train_loss": train_loss,
    "test_loss": validation_loss,
}
with open("model_outputs.pkl", "wb") as f:
    pickle.dump(model_outputs, f)

print("Model and losses saved successfully!")

Model and losses saved successfully!
