In [1]:
import pandas as pd
import torch
import numpy as np

from sklearn.preprocessing import OneHotEncoder, StandardScaler
from torch import nn
from torch.utils.data import DataLoader, Dataset
from src.utils import read_pickles
import torch.optim as optim

pd.options.display.max_columns = 500

In [3]:
movies, users, ratings = read_pickles("../../data/ml-1m-after_eda/")

In [4]:
movies.head()

Unnamed: 0,MovieID,Title,Genres,Year,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western
0,1,Toy Story,"[Animation, Children's, Comedy]",1995,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0
1,2,Jumanji,"[Adventure, Children's, Fantasy]",1995,0,1,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0
2,3,Grumpier Old Men,"[Comedy, Romance]",1995,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0
3,4,Waiting to Exhale,"[Comedy, Drama]",1995,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0
4,5,Father of the Bride Part II,[Comedy],1995,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0


In [5]:
ratings.head()

Unnamed: 0,UserID,MovieID,Rating,Timestamp,Datetime,Date
0,1,1193,5,978300760,2000-12-31 22:12:40,2000-12-31
1,1,661,3,978302109,2000-12-31 22:35:09,2000-12-31
2,1,914,3,978301968,2000-12-31 22:32:48,2000-12-31
3,1,3408,4,978300275,2000-12-31 22:04:35,2000-12-31
4,1,2355,5,978824291,2001-01-06 23:38:11,2001-01-06


1.1. Merge DataFrames

In [6]:
movie_rating = pd.merge(ratings, movies, on='MovieID')
movie_rating_user = pd.merge(movie_rating, users, on='UserID')

# Sort by UserID and Timestamp to maintain sequence order
merged_df = movie_rating_user.sort_values(by=['Timestamp'])

In [7]:
merged_df.head()

Unnamed: 0,UserID,MovieID,Rating,Timestamp,Datetime,Date,Title,Genres,Year,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,Gender,Age,Occupation,Zip-Code,State,Latitude,Longitude
456790,6040,858,4,956703932,2000-04-25 23:05:32,2000-04-25,"Godfather, The","[Action, Crime, Drama]",1972,1,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,M,25,6,11106,NY,40.762012,-73.93147
456672,6040,593,5,956703954,2000-04-25 23:05:54,2000-04-25,"Silence of the Lambs, The","[Drama, Thriller]",1991,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,M,25,6,11106,NY,40.762012,-73.93147
456732,6040,2384,4,956703954,2000-04-25 23:05:54,2000-04-25,Babe: Pig in the City,"[Children's, Comedy]",1998,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,M,25,6,11106,NY,40.762012,-73.93147
456641,6040,1961,4,956703977,2000-04-25 23:06:17,2000-04-25,Rain Man,[Drama],1988,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,M,25,6,11106,NY,40.762012,-73.93147
456842,6040,2019,5,956703977,2000-04-25 23:06:17,2000-04-25,Seven Samurai (The Magnificent Seven) (Shichin...,"[Action, Drama]",1954,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,M,25,6,11106,NY,40.762012,-73.93147


In [8]:
merged_df['watch_year'] = merged_df['Datetime'].dt.year
merged_df['watch_month'] = merged_df['Datetime'].dt.month
merged_df['watch_day'] = merged_df['Datetime'].dt.day
merged_df['watch_hours'] = merged_df['Datetime'].dt.hour
merged_df['watch_is_weekend'] = ((merged_df['Datetime'].dt.day_of_week + 1) > 5).astype('int')

### Categories

In [9]:
one_hot = OneHotEncoder(sparse=False, drop=None)
one_hot.fit(merged_df[['Gender', 'Occupation', 'State']])



In [10]:
encoded_feature_names = one_hot.get_feature_names_out(['Gender', 'Occupation', 'State'])
transformed_categories = one_hot.transform(merged_df[['Gender', 'Occupation', 'State']])

# Create a new DataFrame with the encoded features
transformed_categories_df = pd.DataFrame(transformed_categories, columns=encoded_feature_names)

### Scaling

In [11]:
scaler = StandardScaler()

scaler.fit(merged_df[['Year', 'Age']])

In [12]:
scaled_feature_names = ['scaled_Year', 'scaled_Age']
scaled_categories = scaler.transform(merged_df[['Year', 'Age']])

# Create a new DataFrame with the encoded features
scaled_categories_df = pd.DataFrame(scaled_categories, columns=scaled_feature_names)

In [13]:
encoded_df = pd.concat([merged_df, transformed_categories_df, scaled_categories_df], axis=1)

In [14]:
encoded_df.head()

Unnamed: 0,UserID,MovieID,Rating,Timestamp,Datetime,Date,Title,Genres,Year,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,Gender,Age,Occupation,Zip-Code,State,Latitude,Longitude,watch_year,watch_month,watch_day,watch_hours,watch_is_weekend,Gender_F,Gender_M,Occupation_0,Occupation_1,Occupation_2,Occupation_3,Occupation_4,Occupation_5,Occupation_6,Occupation_7,Occupation_8,Occupation_9,Occupation_10,Occupation_11,Occupation_12,Occupation_13,Occupation_14,Occupation_15,Occupation_16,Occupation_17,Occupation_18,Occupation_19,Occupation_20,State_AK,State_AL,State_AR,State_AZ,State_CA,State_CO,State_CT,State_DC,State_DE,State_FL,State_GA,State_HI,State_IA,State_ID,State_IL,State_IN,State_KS,State_KY,State_LA,State_MA,State_MD,State_ME,State_MI,State_MN,State_MO,State_MS,State_MT,State_NC,State_ND,State_NE,State_NH,State_NJ,State_NM,State_NV,State_NY,State_OH,State_OK,State_OR,State_PA,State_PR,State_RI,State_SC,State_SD,State_TN,State_TX,State_UT,State_VA,State_VT,State_WA,State_WI,State_WV,State_WY,State_None,scaled_Year,scaled_Age
456790,6040,858,4,956703932,2000-04-25 23:05:32,2000-04-25,"Godfather, The","[Action, Crime, Drama]",1972,1,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,M,25,6,11106,NY,40.762012,-73.93147,2000,4,25,23,0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.787605,-0.998837
456672,6040,593,5,956703954,2000-04-25 23:05:54,2000-04-25,"Silence of the Lambs, The","[Drama, Thriller]",1991,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,M,25,6,11106,NY,40.762012,-73.93147,2000,4,25,23,0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.857294,-0.998837
456732,6040,2384,4,956703954,2000-04-25 23:05:54,2000-04-25,Babe: Pig in the City,"[Children's, Comedy]",1998,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,M,25,6,11106,NY,40.762012,-73.93147,2000,4,25,23,0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.717915,-0.998837
456641,6040,1961,4,956703977,2000-04-25 23:06:17,2000-04-25,Rain Man,[Drama],1988,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,M,25,6,11106,NY,40.762012,-73.93147,2000,4,25,23,0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.857294,1.724109
456842,6040,2019,5,956703977,2000-04-25 23:06:17,2000-04-25,Seven Samurai (The Magnificent Seven) (Shichin...,"[Action, Drama]",1954,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,M,25,6,11106,NY,40.762012,-73.93147,2000,4,25,23,0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.717915,-0.403193


### Filter

In [15]:
filtered_encoded_df = encoded_df.drop(['MovieID', 'Timestamp', 'Datetime', 'Title', 'Genres', 'Year', 'Gender', 'Age', 'Occupation', 'Zip-Code', 'State', 'Latitude', 'Longitude'], axis=1)

In [16]:
filtered_encoded_df.head()

Unnamed: 0,UserID,Rating,Date,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,watch_year,watch_month,watch_day,watch_hours,watch_is_weekend,Gender_F,Gender_M,Occupation_0,Occupation_1,Occupation_2,Occupation_3,Occupation_4,Occupation_5,Occupation_6,Occupation_7,Occupation_8,Occupation_9,Occupation_10,Occupation_11,Occupation_12,Occupation_13,Occupation_14,Occupation_15,Occupation_16,Occupation_17,Occupation_18,Occupation_19,Occupation_20,State_AK,State_AL,State_AR,State_AZ,State_CA,State_CO,State_CT,State_DC,State_DE,State_FL,State_GA,State_HI,State_IA,State_ID,State_IL,State_IN,State_KS,State_KY,State_LA,State_MA,State_MD,State_ME,State_MI,State_MN,State_MO,State_MS,State_MT,State_NC,State_ND,State_NE,State_NH,State_NJ,State_NM,State_NV,State_NY,State_OH,State_OK,State_OR,State_PA,State_PR,State_RI,State_SC,State_SD,State_TN,State_TX,State_UT,State_VA,State_VT,State_WA,State_WI,State_WV,State_WY,State_None,scaled_Year,scaled_Age
456790,6040,4,2000-04-25,1,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,2000,4,25,23,0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.787605,-0.998837
456672,6040,5,2000-04-25,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,2000,4,25,23,0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.857294,-0.998837
456732,6040,4,2000-04-25,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,2000,4,25,23,0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.717915,-0.998837
456641,6040,4,2000-04-25,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2000,4,25,23,0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.857294,1.724109
456842,6040,5,2000-04-25,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2000,4,25,23,0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.717915,-0.403193


### Split data

In [17]:
def train_test_split(df, split_date):
    train = df[df["Date"] < split_date]
    test = df[df["Date"] >= split_date]
    train.drop('Date', axis=1, inplace=True)
    test.drop('Date', axis=1, inplace=True)
    return train, test

In [18]:
split_date = pd.to_datetime("2000-12-02").date()
train, test = train_test_split(filtered_encoded_df, split_date)
print(f"Train shape: {train.shape}")
print(f"Test shape: {test.shape}")

Train shape: (797116, 103)
Test shape: (203093, 103)


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train.drop('Date', axis=1, inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test.drop('Date', axis=1, inplace=True)


In [19]:
train.head()

Unnamed: 0,UserID,Rating,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,watch_year,watch_month,watch_day,watch_hours,watch_is_weekend,Gender_F,Gender_M,Occupation_0,Occupation_1,Occupation_2,Occupation_3,Occupation_4,Occupation_5,Occupation_6,Occupation_7,Occupation_8,Occupation_9,Occupation_10,Occupation_11,Occupation_12,Occupation_13,Occupation_14,Occupation_15,Occupation_16,Occupation_17,Occupation_18,Occupation_19,Occupation_20,State_AK,State_AL,State_AR,State_AZ,State_CA,State_CO,State_CT,State_DC,State_DE,State_FL,State_GA,State_HI,State_IA,State_ID,State_IL,State_IN,State_KS,State_KY,State_LA,State_MA,State_MD,State_ME,State_MI,State_MN,State_MO,State_MS,State_MT,State_NC,State_ND,State_NE,State_NH,State_NJ,State_NM,State_NV,State_NY,State_OH,State_OK,State_OR,State_PA,State_PR,State_RI,State_SC,State_SD,State_TN,State_TX,State_UT,State_VA,State_VT,State_WA,State_WI,State_WV,State_WY,State_None,scaled_Year,scaled_Age
456790,6040,4,1,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,2000,4,25,23,0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.787605,-0.998837
456672,6040,5,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,2000,4,25,23,0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.857294,-0.998837
456732,6040,4,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,2000,4,25,23,0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.717915,-0.998837
456641,6040,4,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2000,4,25,23,0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.857294,1.724109
456842,6040,5,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,2000,4,25,23,0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.717915,-0.403193


In [20]:
def get_sequence_by_user_id(df, user_id):
    return df[df['UserID'] == user_id]

### Data preparation

In [22]:
train = train[:10000]
test = test[:5000]

In [23]:
def create_sequences(df, sequence_length):
    X = []
    y = []
    user_ids = df['UserID'].unique()
    for user_id in user_ids:
        user_data = df[df['UserID'] == user_id]
        for i in range(len(user_data) - sequence_length + 1):
            seq_X = user_data.drop(['UserID', 'Rating'], axis=1).iloc[i:i + sequence_length].values
            seq_y = user_data['Rating'].iloc[i:i + sequence_length].values
            X.append(seq_X)
            y.append(seq_y)
    return np.array(X), np.array(y)

X_train, y_train = create_sequences(train, 5)
X_test, y_test = create_sequences(test, 5)

In [24]:
X_train.shape

(9684, 5, 101)

Model

In [25]:
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)

In [26]:
# Create a custom Dataset class
class MovieRatingDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
    

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size

    def forward(self, x):
        h_0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)
        c_0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)
        
        out, _ = self.lstm(x, (h_0, c_0))
        out = self.fc(out[:, -1, :])  # Take the output from the last time step
        return out

In [28]:
train_dataset = MovieRatingDataset(X_train, y_train)
test_dataset = MovieRatingDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

input_size = X_train.shape[2]
hidden_size = 64
num_layers = 1
output_size = 5


model = LSTMModel(input_size, hidden_size, num_layers, output_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 25
for epoch in range(num_epochs):
    model.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

Epoch [1/25], Loss: 2.6785
Epoch [2/25], Loss: 0.5868
Epoch [3/25], Loss: 1.0892
Epoch [4/25], Loss: 1.0958
Epoch [5/25], Loss: 3.7688
Epoch [6/25], Loss: 0.6844
Epoch [7/25], Loss: 1.0635
Epoch [8/25], Loss: 1.0376
Epoch [9/25], Loss: 1.3222
Epoch [10/25], Loss: 3.6151
Epoch [11/25], Loss: 2.2600
Epoch [12/25], Loss: 0.6862
Epoch [13/25], Loss: 1.4435
Epoch [14/25], Loss: 1.2155
Epoch [15/25], Loss: 0.5735
Epoch [16/25], Loss: 0.6099
Epoch [17/25], Loss: 1.8790
Epoch [18/25], Loss: 0.7301
Epoch [19/25], Loss: 1.7629
Epoch [20/25], Loss: 0.4414
Epoch [21/25], Loss: 0.6976
Epoch [22/25], Loss: 1.4915
Epoch [23/25], Loss: 0.5792
Epoch [24/25], Loss: 0.4994
Epoch [25/25], Loss: 0.7807


In [29]:
model.eval()
with torch.no_grad():
    test_loss = 0
    for inputs, targets in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        test_loss += loss.item()
    print(f'Test Loss: {test_loss / len(test_loader):.4f}')

Test Loss: 1.1347
