# Prepare data section

In [None]:
%pip install pytorch_lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import pandas as pd
import torch
import pytorch_lightning as pl
from tqdm import tqdm
import torchmetrics
import math
from urllib.request import urlretrieve
from zipfile import ZipFile
import os
import torch.nn as nn
import numpy as np
from math import sqrt

## Settings

In [None]:
# NOTE(review): WINDOW_SIZE is not referenced by any cell visible in this notebook;
# the sequence-building cells use `sequence_length` instead — confirm it is still needed.
WINDOW_SIZE = 20

## Data

In [None]:
# Download the MovieLens 1M archive into the working directory.
urlretrieve("http://files.grouplens.org/datasets/movielens/ml-1m.zip", "movielens.zip")
# Extract inside a context manager so the archive's file handle is closed afterwards.
with ZipFile("movielens.zip", "r") as archive:
    archive.extractall()

In [None]:
# The .dat files use the two-character separator "::". pandas treats separators longer
# than one character as regular expressions, which only its Python parser supports;
# requesting that engine explicitly avoids the ParserWarning emitted on fallback.
users = pd.read_csv(
    "ml-1m/users.dat",
    sep="::",
    names=["user_id", "sex", "age_group", "occupation", "zip_code"],
    engine="python",
)

ratings = pd.read_csv(
    "ml-1m/ratings.dat",
    sep="::",
    names=["user_id", "movie_id", "rating", "unix_timestamp"],
    engine="python",
)

movies = pd.read_csv(
    "ml-1m/movies.dat",
    sep="::",
    names=["movie_id", "title", "genres"],
    encoding="ISO-8859-1",
    engine="python",
)

  return func(*args, **kwargs)


In [None]:
## Movies
# Take the four characters just before the final character of each title
# (the "YYYY" in a trailing "(YYYY)") and store them as a categorical column.
release_year = movies["title"].str[-5:-1]
movies["year"] = pd.Categorical(release_year)

In [None]:

# Replace the categorical year (built in the previous cell) with its integer code.
movies["year"] = movies.year.cat.codes

## Users
# Every user attribute gets the same treatment: map each distinct value to an integer code.
for user_column in ("sex", "age_group", "occupation", "zip_code"):
    users[user_column] = pd.Categorical(users[user_column]).codes

#Ratings
# Interpret the raw timestamps as seconds since the Unix epoch.
ratings["unix_timestamp"] = pd.to_datetime(ratings["unix_timestamp"], unit="s")


In [None]:
# Save primary csv's
# `exist_ok=True` makes this cell safe to re-run and avoids the check-then-create race.
os.makedirs("data", exist_ok=True)

users.to_csv("data/users.csv",index=False)
movies.to_csv("data/movies.csv",index=False)
ratings.to_csv("data/ratings.csv",index=False)

In [None]:
# Cast the identifier columns of each table to strings.
## Movies, Users, Ratings
for frame, id_columns in (
    (movies, ("movie_id",)),
    (users, ("user_id",)),
    (ratings, ("movie_id", "user_id")),
):
    for id_column in id_columns:
        frame[id_column] = frame[id_column].astype(str)

In [None]:
movies

Unnamed: 0,movie_id,title,genres,year
0,1,Toy Story (1995),Animation|Children's|Comedy,75
1,2,Jumanji (1995),Adventure|Children's|Fantasy,75
2,3,Grumpier Old Men (1995),Comedy|Romance,75
3,4,Waiting to Exhale (1995),Comedy|Drama,75
4,5,Father of the Bride Part II (1995),Comedy,75
...,...,...,...,...
3878,3948,Meet the Parents (2000),Comedy,80
3879,3949,Requiem for a Dream (2000),Drama,80
3880,3950,Tigerland (2000),Drama,80
3881,3951,Two Family House (2000),Drama,80


In [None]:
# Genre labels that appear, pipe-separated, in the `genres` column of `movies`.
genres = [
    "Action", "Adventure", "Animation", "Children's", "Comedy", "Crime",
    "Documentary", "Drama", "Fantasy", "Film-Noir", "Horror", "Musical",
    "Mystery", "Romance", "Sci-Fi", "Thriller", "War", "Western",
]

# Split each movie's genre string once, then add one 0/1 indicator column per genre.
genre_tags = movies["genres"].str.split("|")
for genre in genres:
    movies[genre] = genre_tags.map(lambda tags, wanted=genre: int(wanted in tags))


In [None]:
movies

Unnamed: 0,movie_id,title,genres,year,Action,Adventure,Animation,Children's,Comedy,Crime,...,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western
0,1,Toy Story (1995),Animation|Children's|Comedy,75,0,0,1,1,1,0,...,0,0,0,0,0,0,0,0,0,0
1,2,Jumanji (1995),Adventure|Children's|Fantasy,75,0,1,0,1,0,0,...,1,0,0,0,0,0,0,0,0,0
2,3,Grumpier Old Men (1995),Comedy|Romance,75,0,0,0,0,1,0,...,0,0,0,0,0,1,0,0,0,0
3,4,Waiting to Exhale (1995),Comedy|Drama,75,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
4,5,Father of the Bride Part II (1995),Comedy,75,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3878,3948,Meet the Parents (2000),Comedy,80,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
3879,3949,Requiem for a Dream (2000),Drama,80,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3880,3950,Tigerland (2000),Drama,80,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3881,3951,Two Family House (2000),Drama,80,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


### Transform the movie ratings data into sequences

First, let's sort the ratings data by `unix_timestamp`, and then group the
`movie_id` values and the `rating` values by `user_id`.

The output DataFrame will have a record for each `user_id`, with two ordered lists
(sorted by rating datetime): the movies they have rated, and their ratings of these movies.

In [None]:
# Sort every rating chronologically, then bucket the rows per user so that each
# per-user list below keeps that chronological order.
ratings_group = ratings.sort_values(by=["unix_timestamp"]).groupby("user_id")

# One row per user: the ordered movie ids, ratings and timestamps as Python lists.
ratings_data = ratings_group.agg(
    movie_ids=("movie_id", list),
    ratings=("rating", list),
    timestamps=("unix_timestamp", list),
).reset_index()


In [None]:
ratings_data

Unnamed: 0,user_id,movie_ids,ratings,timestamps
0,1,"[3186, 1721, 1270, 1022, 2340, 1836, 3408, 120...","[4, 4, 5, 5, 3, 5, 4, 4, 5, 4, 3, 5, 4, 4, 4, ...","[2000-12-31 22:00:19, 2000-12-31 22:00:55, 200..."
1,10,"[597, 858, 743, 1210, 1948, 2312, 3751, 1282, ...","[4, 3, 3, 4, 4, 5, 5, 5, 3, 3, 3, 5, 4, 4, 4, ...","[2000-12-31 00:59:35, 2000-12-31 00:59:35, 200..."
2,100,"[260, 1676, 1198, 541, 1210, 3948, 3536, 2567,...","[4, 3, 4, 3, 4, 3, 1, 1, 5, 4, 4, 3, 2, 3, 4, ...","[2000-12-23 17:46:35, 2000-12-23 17:46:35, 200..."
3,1000,"[971, 260, 2990, 2973, 1210, 3068, 3153, 1198,...","[4, 5, 4, 3, 5, 5, 2, 5, 5, 4, 5, 4, 3, 5, 5, ...","[2000-11-24 04:36:06, 2000-11-24 04:36:06, 200..."
4,1001,"[1198, 1617, 2885, 3909, 3555, 1479, 3903, 394...","[4, 4, 4, 2, 2, 1, 4, 5, 5, 4, 4, 4, 4, 3, 4, ...","[2000-11-24 04:19:51, 2000-11-24 04:21:42, 200..."
...,...,...,...,...
6035,995,"[1894, 260, 247, 433, 170, 74, 912, 3097, 1265...","[2, 4, 5, 3, 3, 4, 4, 4, 3, 5, 5, 5, 5, 5, 5, ...","[2000-11-24 08:33:05, 2000-11-24 08:33:05, 200..."
6036,996,"[1347, 2146, 1961, 2741, 1210, 527, 1196, 1213...","[4, 3, 5, 3, 5, 5, 5, 5, 4, 2, 5, 5, 5, 4, 5, ...","[2000-11-24 07:48:52, 2000-11-24 07:48:52, 200..."
6037,997,"[1196, 2082, 3247, 2447, 2633, 2028, 593, 318,...","[4, 3, 3, 3, 2, 5, 5, 5, 4, 4, 5, 4, 4, 3, 4, ...","[2000-11-24 05:37:15, 2000-11-24 05:40:25, 200..."
6038,998,"[2266, 1264, 1097, 1641, 805, 1388, 1968, 3751...","[3, 4, 5, 5, 4, 3, 4, 3, 4, 4, 4, 4, 5, 5, 4, ...","[2000-11-24 05:24:59, 2000-11-24 05:26:33, 200..."


Now, let's split the `movie_ids` list into a set of sequences of a fixed length.
We do the same for the `ratings`. Set the `sequence_length` variable to change the length
of the input sequence to the model. You can also change the `step_size` to control the
number of sequences to generate for each user.

In [None]:
# Number of consecutive ratings in each generated sequence (passed as `window_size` below).
sequence_length = 16
# Offset between the start positions of consecutive sequences.
step_size = 1


def create_sequences(values, window_size, step_size):
    """Cut `values` into fixed-length sliding windows.

    Windows start at 0, step_size, 2 * step_size, ... for as long as a full
    window fits. If those windows stop short of the end of `values`, one extra
    window covering the last `window_size` items is appended so the tail is
    not dropped. Each window appears at most once.

    Args:
        values: Sliceable sequence (e.g. a list) to cut up.
        window_size: Length of every returned window; must be positive.
        step_size: Distance between consecutive window starts; must be positive.

    Returns:
        A list of slices of `values`, each of length `window_size`. Empty when
        `values` holds fewer than `window_size` items.

    Raises:
        ValueError: If `window_size` or `step_size` is not positive (the
            sliding loop could not terminate for such values).
    """
    if window_size <= 0 or step_size <= 0:
        raise ValueError("window_size and step_size must be positive")

    total = len(values)
    if total < window_size:
        return []

    last_start = total - window_size
    starts = list(range(0, last_start + 1, step_size))
    # Add the tail window only when the strided windows did not already end on
    # it; appending unconditionally would duplicate the final sequence.
    if starts[-1] != last_start:
        starts.append(last_start)
    return [values[start:start + window_size] for start in starts]


# Window the per-user movie ids and ratings with identical settings so the two
# columns stay aligned position by position.
for list_column in ("movie_ids", "ratings"):
    ratings_data[list_column] = ratings_data[list_column].apply(
        lambda user_values: create_sequences(user_values, sequence_length, step_size)
    )

# The timestamps are no longer needed once the lists are in chronological order.
ratings_data.drop(columns=["timestamps"], inplace=True)

In [None]:
ratings_data

Unnamed: 0,user_id,movie_ids,ratings
0,1,"[[3186, 1721, 1270, 1022, 2340, 1836, 3408, 12...","[[4, 4, 5, 5, 3, 5, 4, 4, 5, 4, 3, 5, 4, 4, 4,..."
1,10,"[[597, 858, 743, 1210, 1948, 2312, 3751, 1282,...","[[4, 3, 3, 4, 4, 5, 5, 5, 3, 3, 3, 5, 4, 4, 4,..."
2,100,"[[260, 1676, 1198, 541, 1210, 3948, 3536, 2567...","[[4, 3, 4, 3, 4, 3, 1, 1, 5, 4, 4, 3, 2, 3, 4,..."
3,1000,"[[971, 260, 2990, 2973, 1210, 3068, 3153, 1198...","[[4, 5, 4, 3, 5, 5, 2, 5, 5, 4, 5, 4, 3, 5, 5,..."
4,1001,"[[1198, 1617, 2885, 3909, 3555, 1479, 3903, 39...","[[4, 4, 4, 2, 2, 1, 4, 5, 5, 4, 4, 4, 4, 3, 4,..."
...,...,...,...
6035,995,"[[1894, 260, 247, 433, 170, 74, 912, 3097, 126...","[[2, 4, 5, 3, 3, 4, 4, 4, 3, 5, 5, 5, 5, 5, 5,..."
6036,996,"[[1347, 2146, 1961, 2741, 1210, 527, 1196, 121...","[[4, 3, 5, 3, 5, 5, 5, 5, 4, 2, 5, 5, 5, 4, 5,..."
6037,997,"[[1196, 2082, 3247, 2447, 2633, 2028, 593, 318...","[[4, 3, 3, 3, 2, 5, 5, 5, 4, 4, 5, 4, 4, 3, 4,..."
6038,998,"[[2266, 1264, 1097, 1641, 805, 1388, 1968, 375...","[[3, 4, 5, 5, 4, 3, 4, 3, 4, 4, 4, 4, 5, 5, 4,..."


After that, we process the output so that each sequence becomes a separate record in
the DataFrame. In addition, we join the user features with the ratings data.

In [None]:
# Turn each user's list of sequences into one row per sequence. The two columns are
# exploded separately and re-attached by position, which relies on both holding the
# same number of sequences per user.
ratings_data_movies = ratings_data.loc[:, ["user_id", "movie_ids"]].explode(
    "movie_ids", ignore_index=True
)
ratings_data_rating = ratings_data.loc[:, ["ratings"]].explode(
    "ratings", ignore_index=True
)

# Attach the user attributes to every sequence row via `user_id`.
user_features = users.set_index("user_id")
ratings_data_transformed = pd.concat(
    [ratings_data_movies, ratings_data_rating], axis=1
).join(user_features, on="user_id")

In [None]:
ratings_data_transformed

Unnamed: 0,user_id,movie_ids,ratings,sex,age_group,occupation,zip_code
0,1,"[3186, 1721, 1270, 1022, 2340, 1836, 3408, 120...","[4, 4, 5, 5, 3, 5, 4, 4, 5, 4, 3, 5, 4, 4, 4, 5]",0,0,10,1588
1,1,"[1721, 1270, 1022, 2340, 1836, 3408, 1207, 280...","[4, 5, 5, 3, 5, 4, 4, 5, 4, 3, 5, 4, 4, 4, 5, 5]",0,0,10,1588
2,1,"[1270, 1022, 2340, 1836, 3408, 1207, 2804, 260...","[5, 5, 3, 5, 4, 4, 5, 4, 3, 5, 4, 4, 4, 5, 5, 5]",0,0,10,1588
3,1,"[1022, 2340, 1836, 3408, 1207, 2804, 260, 720,...","[5, 3, 5, 4, 4, 5, 4, 3, 5, 4, 4, 4, 5, 5, 5, 4]",0,0,10,1588
4,1,"[2340, 1836, 3408, 1207, 2804, 260, 720, 1193,...","[3, 5, 4, 4, 5, 4, 3, 5, 4, 4, 4, 5, 5, 5, 4, 5]",0,0,10,1588
...,...,...,...,...,...,...,...
915644,999,"[79, 2875, 2316, 2165, 361, 2688, 24, 2264, 19...","[3, 4, 3, 1, 3, 3, 3, 2, 1, 3, 2, 3, 3, 4, 2, 3]",1,2,15,2128
915645,999,"[2875, 2316, 2165, 361, 2688, 24, 2264, 1959, ...","[4, 3, 1, 3, 3, 3, 2, 1, 3, 2, 3, 3, 4, 2, 3, 3]",1,2,15,2128
915646,999,"[2316, 2165, 361, 2688, 24, 2264, 1959, 2676, ...","[3, 1, 3, 3, 3, 2, 1, 3, 2, 3, 3, 4, 2, 3, 3, 2]",1,2,15,2128
915647,999,"[2165, 361, 2688, 24, 2264, 1959, 2676, 2540, ...","[1, 3, 3, 3, 2, 1, 3, 2, 3, 3, 4, 2, 3, 3, 2, 2]",1,2,15,2128


In [None]:

# Flatten each sequence to a comma-separated string so it survives the CSV round trip,
# drop the zip code feature, and give the sequence columns their final names.
ratings_data_transformed = (
    ratings_data_transformed.assign(
        movie_ids=ratings_data_transformed["movie_ids"].map(",".join),
        ratings=ratings_data_transformed["ratings"].map(
            lambda sequence: ",".join(str(value) for value in sequence)
        ),
    )
    .drop(columns=["zip_code"])
    .rename(
        columns={"movie_ids": "sequence_movie_ids", "ratings": "sequence_ratings"}
    )
)

In [None]:
ratings_data_transformed.shape

(915649, 6)

With `sequence_length` of 16 and `step_size` of 1, we end up with 915,649 sequences (the row count shown above).

Finally, we split the data into training and testing splits, with 85% and 15% of
the instances, respectively, and store them to CSV files.

In [None]:
# Draw the split mask from a seeded generator so the same rows land in train and
# test on every run; each row is assigned to train with probability 0.85.
split_rng = np.random.default_rng(42)
random_selection = split_rng.random(len(ratings_data_transformed.index)) <= 0.85
train_data = ratings_data_transformed[random_selection]
test_data = ratings_data_transformed[~random_selection]

# NOTE(review): rows are overlapping sliding windows, so a row-level random split lets
# test sequences share most of their items with train sequences — confirm this is intended.
train_data.to_csv("data/train_data.csv", index=False, sep=",")
test_data.to_csv("data/test_data.csv", index=False, sep=",")

In [None]:
test_data

Unnamed: 0,user_id,sequence_movie_ids,sequence_ratings,sex,age_group,occupation
1,1,"1721,1270,1022,2340,1836,3408,1207,2804,260,72...",4553544543544455,0,0,10
5,1,"1836,3408,1207,2804,260,720,1193,919,608,2692,...",5445435444555454,0,0,10
10,1,"720,1193,919,608,2692,1961,2028,3105,938,1035,...",3544455545454543,0,0,10
12,1,"919,608,2692,1961,2028,3105,938,1035,1962,1028...",4445554545454354,0,0,10
14,1,"2692,1961,2028,3105,938,1035,1962,1028,2018,15...",4555454545435444,0,0,10
...,...,...,...,...,...,...
915596,999,"605,22,253,3259,1124,1183,271,1598,3174,3409,2...",2423323244334344,1,2,15
915610,999,"1589,3791,225,507,3173,3176,354,524,280,2447,1...",4443344443124343,1,2,15
915619,999,"2447,1515,724,266,450,2975,371,382,1027,1442,2...",3124343344343434,1,2,15
915643,999,"207,79,2875,2316,2165,361,2688,24,2264,1959,26...",4343133321323342,1,2,15


# BST Implementation and training

In [None]:
import pandas as pd
import torch
import pytorch_lightning as pl
from tqdm import tqdm
import torchmetrics
import math
from urllib.request import urlretrieve
from zipfile import ZipFile
import os
import torch.nn as nn
import numpy as np

In [None]:
# Reload the encoded tables written by the data-preparation section.
users = pd.read_csv("data/users.csv", sep=",")
ratings = pd.read_csv("data/ratings.csv", sep=",")
movies = pd.read_csv("data/movies.csv", sep=",")

In [None]:
users.max()

user_id       6040
sex              1
age_group        6
occupation      20
zip_code      3438
dtype: int64

In [None]:
movies.max()

movie_id               3952
title       eXistenZ (1999)
genres              Western
year                     80
dtype: object

## Pytorch dataset

In [None]:
import pandas as pd
import torch
import torch.utils.data as data
import ast
from torch.nn.utils.rnn import pad_sequence

class MovieDataset(data.Dataset):
    """Dataset of per-user rating sequences read from a CSV file.

    Each CSV row holds two comma-separated strings, `sequence_movie_ids` and
    `sequence_ratings`. The last element of each is treated as the prediction
    target and the preceding elements as the history.
    """

    def __init__(
        self, ratings_file,test=False
    ):
        """
        Args:
            ratings_file (string): Path to a CSV with the columns user_id,
                sequence_movie_ids, sequence_ratings, sex, age_group, occupation.
            test (bool): Stored on the instance; not read by this class.
        """
        self.ratings_frame = pd.read_csv(
            ratings_file,
            delimiter=",",
        )
        self.test = test

    @staticmethod
    def _parse_int_sequence(text):
        """Parse a comma-separated string such as "12,7,3" into a list of ints."""
        # Explicit parsing instead of eval(): the CSV content is data, not code.
        return [int(token) for token in str(text).split(",")]

    def __len__(self):
        """Return the number of sequences (rows) in the CSV."""
        return len(self.ratings_frame)

    def __getitem__(self, idx):
        """Return one sample.

        Returns:
            Tuple of (user_id, movie_history, target_movie_id,
            movie_history_ratings, target_movie_rating, sex, age_group,
            occupation). The two histories are LongTensors holding every
            element except the last; the targets are plain ints.
        """
        # Named `row` so the module alias `data` (torch.utils.data) is not shadowed.
        row = self.ratings_frame.iloc[idx]
        user_id = row.user_id

        movie_ids = self._parse_int_sequence(row.sequence_movie_ids)
        rating_values = self._parse_int_sequence(row.sequence_ratings)
        target_movie_id = movie_ids[-1]
        target_movie_rating = rating_values[-1]

        movie_history = torch.LongTensor(movie_ids[:-1])
        movie_history_ratings = torch.LongTensor(rating_values[:-1])

        sex = row.sex
        age_group = row.age_group
        occupation = row.occupation

        return user_id, movie_history, target_movie_id,  movie_history_ratings, target_movie_rating, sex, age_group, occupation

In [None]:
# Genre labels that appear, pipe-separated, in the `genres` column of `movies`.
genres = [
    "Action", "Adventure", "Animation", "Children's", "Comedy", "Crime",
    "Documentary", "Drama", "Fantasy", "Film-Noir", "Horror", "Musical",
    "Mystery", "Romance", "Sci-Fi", "Thriller", "War", "Western",
]

# Rebuild the 0/1 genre indicator columns on the `movies` frame reloaded from CSV.
genre_tags = movies["genres"].str.split("|")
for genre in genres:
    movies[genre] = genre_tags.map(lambda tags, wanted=genre: int(wanted in tags))

# NOTE(review): the data-preparation section builds sequences with sequence_length = 16,
# and no visible cell below reads this value — confirm whether 8 is intentional.
sequence_length = 8

In [None]:
class LSTRDataset(data.Dataset):
    """Movie dataset for long-short term sequences.

    Each CSV row holds two comma-separated strings, `sequence_movie_ids` and
    `sequence_ratings`. The first `long_history_length` elements form the
    long-term history, the last element is the prediction target, and the
    elements in between form the short-term history.
    """

    def __init__(
        self, ratings_file,test=False, long_history_length=12
    ):
        """
        Args:
            ratings_file (string): Path to a CSV with the columns user_id,
                sequence_movie_ids, sequence_ratings, sex, age_group, occupation.
            test (bool): Stored on the instance; not read by this class.
            long_history_length (int): Number of leading sequence elements
                assigned to the long-term history. Defaults to 12, the split
                this class previously hard-coded.
        """
        self.ratings_frame = pd.read_csv(
            ratings_file,
            delimiter=",",
        )
        self.test = test
        self.long_history_length = long_history_length

    @staticmethod
    def _parse_int_sequence(text):
        """Parse a comma-separated string such as "12,7,3" into a list of ints."""
        # Explicit parsing instead of eval(): the CSV content is data, not code.
        return [int(token) for token in str(text).split(",")]

    def __len__(self):
        """Return the number of sequences (rows) in the CSV."""
        return len(self.ratings_frame)

    def __getitem__(self, idx):
        """Return one sample.

        Returns:
            Tuple of (user_id, movie_long_history, movie_history,
            target_movie_id, movie_history_long_ratings,
            movie_history_ratings, target_movie_rating, sex, age_group,
            occupation). The four histories are LongTensors; the targets are
            plain ints.
        """
        # Named `row` so the module alias `data` (torch.utils.data) is not shadowed.
        row = self.ratings_frame.iloc[idx]
        user_id = row.user_id
        movie_ids = self._parse_int_sequence(row.sequence_movie_ids)
        rating_values = self._parse_int_sequence(row.sequence_ratings)
        target_movie_id = movie_ids[-1]
        target_movie_rating = rating_values[-1]

        split = self.long_history_length
        movie_long_history = torch.LongTensor(movie_ids[:split])
        movie_history = torch.LongTensor(movie_ids[split:-1])
        movie_history_long_ratings = torch.LongTensor(rating_values[:split])
        movie_history_ratings = torch.LongTensor(rating_values[split:-1])

        sex = row.sex
        age_group = row.age_group
        occupation = row.occupation

        return user_id, movie_long_history, movie_history, target_movie_id,  movie_history_long_ratings, movie_history_ratings, target_movie_rating, sex, age_group, occupation

In [None]:
class PositionalEmbedding(nn.Module):
    """
    Learned positional embedding: a trainable table with one `d_model`-wide
    vector for each of `max_len` positions.
    """

    def __init__(self, max_len, d_model):
        super().__init__()

        # Trainable lookup table, one row per sequence position.
        self.pe = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # Only the batch size of `x` is read: every sample receives a copy of the
        # full table, giving a (batch, max_len, d_model) tensor.
        position_table = self.pe.weight[None, :, :]
        return position_table.repeat(x.size(0), 1, 1)


class BST(pl.LightningModule):
    """Rating-prediction model over a long-term and a short-term movie history.

    The long-term history is encoded by its own transformer layer, concatenated
    with the short-term history plus the target movie, passed through a second
    transformer layer, flattened, joined with user embeddings and regressed to
    a single rating by an MLP.

    The constructor reads the notebook-level `users`, `movies` and `genres`
    objects to size the embedding tables, and `setup` loads the train/test CSVs.
    """

    def __init__(
        self, args=None,
    ):
        super().__init__()

        self.save_hyperparameters()
        self.args = args
        #-------------------
        # Embedding layers
        ##Users
        self.embeddings_user_id = nn.Embedding(
            int(users.user_id.max())+1, int(math.sqrt(users.user_id.max()))+1
        )
        ###Users features embeddings
        self.embeddings_user_sex = nn.Embedding(
            len(users.sex.unique()), int(math.sqrt(len(users.sex.unique())))
        )
        self.embeddings_age_group = nn.Embedding(
            len(users.age_group.unique()), int(math.sqrt(len(users.age_group.unique())))
        )
        self.embeddings_user_occupation = nn.Embedding(
            len(users.occupation.unique()), int(math.sqrt(len(users.occupation.unique())))
        )
        # NOTE(review): the width below is derived from `users.sex`, not `users.zip_code`,
        # which looks like a copy-paste slip. The layer is not used in `encode_input`, and
        # changing its shape would alter the state dict, so it is left as-is — confirm.
        self.embeddings_user_zip_code = nn.Embedding(
            len(users.zip_code.unique()), int(math.sqrt(len(users.sex.unique())))
        )

        ##Movies
        self.embeddings_movie_id = nn.Embedding(
            int(movies.movie_id.max())+1, int(math.sqrt(movies.movie_id.max()))+1
        )

        ###Movies features embeddings
        # The genre and year embeddings are created but not used in `encode_input`.
        genre_vectors = movies[genres].to_numpy()
        self.embeddings_movie_genre = nn.Embedding(
            genre_vectors.shape[0], genre_vectors.shape[1]
        )

        self.embeddings_movie_year = nn.Embedding(
            len(movies.year.unique()), int(math.sqrt(len(movies.year.unique())))
        )

        # Token width 72 = movie-id embedding width + 9 positional dims.
        # Assumes the movie-id embedding is 63 wide (int(sqrt(3952)) + 1) — TODO confirm
        # against the loaded `movies` frame.
        # `batch_first=True`: `forward` feeds (batch, seq, feature) tensors; with the
        # default layout attention would run across the batch axis instead of positions.
        self.long_positional_embedding = PositionalEmbedding(12, 9)
        self.long_encoding_layer = nn.TransformerEncoderLayer(
            72, 3, dropout=0.2, batch_first=True
        )
        self.positional_embedding = PositionalEmbedding(4, 9)

        # Network
        self.transfomerlayer = nn.TransformerEncoderLayer(
            72, 3, dropout=0.2, batch_first=True
        )
        # Input width 1237 presumably = 16 positions * 72 + concatenated user embedding
        # widths — verify if any embedding size above changes.
        self.linear = nn.Sequential(
            nn.Linear(1237, 2048),
            nn.LeakyReLU(),
            nn.Linear(2048, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
        )
        self.criterion = torch.nn.MSELoss()
        self.mae = torchmetrics.MeanAbsoluteError()
        self.mse = torchmetrics.MeanSquaredError()

    def encode_input(self,inputs):
        """Embed one batch.

        Args:
            inputs: 10-tuple in the order returned by `LSTRDataset.__getitem__`.

        Returns:
            Tuple of (long-history movie embeddings, short-history movie
            embeddings with the target movie appended as the last position,
            concatenated user embeddings, target rating as float).
        """
        user_id, movie_long_history, movie_history, target_movie_id,  movie_history_long_ratings, movie_history_ratings, target_movie_rating, sex, age_group, occupation = inputs

        #MOVIES
        transformer_long_features = self.embeddings_movie_id(movie_long_history)
        movie_history = self.embeddings_movie_id(movie_history)
        target_movie = self.embeddings_movie_id(target_movie_id)

        # Give the target a sequence axis so it can be appended to the history.
        target_movie = torch.unsqueeze(target_movie, 1)
        transfomer_features = torch.cat((movie_history, target_movie),dim=1)

        #USERS
        user_id = self.embeddings_user_id(user_id)

        sex = self.embeddings_user_sex(sex)
        age_group = self.embeddings_age_group(age_group)
        occupation = self.embeddings_user_occupation(occupation)
        user_features = torch.cat((user_id, sex, age_group,occupation), 1)

        return transformer_long_features, transfomer_features, user_features, target_movie_rating.float()

    def forward(self, batch):
        """Predict the rating of the target movie for every sample in `batch`.

        Returns:
            Tuple of (predictions from the final `nn.Linear(256, 1)`, target
            ratings as float).
        """
        transformer_long_features, transfomer_features, user_features, target_movie_rating = self.encode_input(batch)

        # Long-term branch: append positional dims, then encode.
        long_positional_embedding = self.long_positional_embedding(transformer_long_features)
        long_transfomer_features = torch.cat((transformer_long_features, long_positional_embedding), dim=2)
        transformer_long_output = self.long_encoding_layer(long_transfomer_features)

        # Short-term branch (history + target): append positional dims only.
        positional_embedding = self.positional_embedding(transfomer_features)
        transfomer_features = torch.cat((transfomer_features, positional_embedding), dim=2)

        # Stack both branches along the sequence axis and encode them jointly.
        transformer_combined_features = torch.cat((transformer_long_output, transfomer_features), dim=1)
        transformer_output = self.transfomerlayer(transformer_combined_features)

        # An unfinished item-score computation used to sit here: it took a softmax of the
        # undefined name `item_prediction_scores` (NameError on every call) and its
        # result never reached the output, so it has been removed.
        transformer_output = torch.flatten(transformer_output,start_dim=1)

        #Concat with other features
        features = torch.cat((transformer_output,user_features),dim=1)

        output = self.linear(features)
        return output, target_movie_rating

    def training_step(self, batch, batch_idx):
        """Compute the MSE loss for one training batch and log step-level MAE/RMSE."""
        out, target_movie_rating = self(batch)
        out = out.flatten()
        loss = self.criterion(out, target_movie_rating)

        mae = self.mae(out, target_movie_rating)
        mse = self.mse(out, target_movie_rating)
        rmse =torch.sqrt(mse)
        self.log(
            "train/mae", mae, on_step=True, on_epoch=False, prog_bar=False
        )

        self.log(
            "train/rmse", rmse, on_step=True, on_epoch=False, prog_bar=False
        )

        self.log("train/step_loss", loss, on_step=True, on_epoch=False, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute loss, MAE and RMSE for one validation batch and return them."""
        out, target_movie_rating = self(batch)
        out = out.flatten()
        loss = self.criterion(out, target_movie_rating)

        mae = self.mae(out, target_movie_rating)
        mse = self.mse(out, target_movie_rating)
        rmse =torch.sqrt(mse)

        self.log(
            "val/mae", mae, on_step=True, on_epoch=False, prog_bar=False
        )

        self.log(
            "val/rmse", rmse, on_step=True, on_epoch=False, prog_bar=False
        )

        self.log("val/step_loss", loss, on_step=True, on_epoch=False, prog_bar=False)

        return {"val_loss": loss, "mae": mae.detach(), "rmse":rmse.detach()}

    def validation_epoch_end(self, outputs):
        """Log the unweighted mean of the per-batch validation loss, MAE and RMSE."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_mae = torch.stack([x["mae"] for x in outputs]).mean()
        avg_rmse = torch.stack([x["rmse"] for x in outputs]).mean()

        self.log("val/loss", avg_loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/mae", avg_mae, on_step=False, on_epoch=True, prog_bar=False)
        self.log("val/rmse", avg_rmse, on_step=False, on_epoch=True, prog_bar=False)

    def test_epoch_end(self, outputs):
        """Collect per-batch "users"/"top14" tensors and write them to a CSV.

        NOTE(review): this class defines no `test_step`, so nothing visible here
        produces the "users"/"top14" entries this hook expects — confirm before
        calling `trainer.test`.
        """
        users = torch.cat([x["users"] for x in outputs])
        y_hat = torch.cat([x["top14"] for x in outputs])
        users = users.tolist()
        y_hat = y_hat.tolist()

        data = {"users": users, "top14": y_hat}
        df = pd.DataFrame.from_dict(data)
        print(len(df))
        df.to_csv("lightning_logs/predict.csv", index=False)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with a fixed lr of 0.0005."""
        return torch.optim.AdamW(self.parameters(), lr=0.0005)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending `parent_parser` with `--learning_rate`."""
        # Imported locally: the notebook never imports ArgumentParser at the top.
        from argparse import ArgumentParser

        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--learning_rate", type=float, default=0.01)
        return parser

    ####################
    # DATA RELATED HOOKS
    ####################

    def setup(self, stage=None):
        """Load the train CSV and use the test CSV for both validation and test."""
        print("Loading datasets")
        self.train_dataset = LSTRDataset("data/train_data.csv")
        self.val_dataset = LSTRDataset("data/test_data.csv")
        self.test_dataset = LSTRDataset("data/test_data.csv")
        print("Done")

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with the batch size and worker count shared by all splits."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=128,
            shuffle=shuffle,
            # os.cpu_count() may return None; DataLoader needs an int.
            num_workers=os.cpu_count() or 0,
        )

    def train_dataloader(self):
        # Shuffle: the training CSV is written grouped by user, so unshuffled batches
        # would each hold near-duplicate windows from very few users.
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)
        
model = BST()
# "auto" selects a GPU when one is available and otherwise runs on CPU, so the cell
# no longer fails on machines without CUDA.
trainer = pl.Trainer(accelerator="auto", devices=1, max_epochs=10)
trainer.fit(model)

: 

In [None]:
a = torch.randn(128, 16, 72)
b = torch.randn(128, 16, 72)
c = torch.bmm(a[:,0,:].unsqueeze(1), b[:,1:,:].transpose(1,2))

In [None]:
# Convert tensor of shape [128] to [128, 1]


In [None]:
trainer.validate(model)

In [None]:
# Relative path, matching the `LSTRDataset("data/pred_data.csv")` cell below, so the
# notebook is not tied to a /content working directory.
# NOTE(review): no visible cell writes data/pred_data.csv — confirm where it comes from.
test_data = pd.read_csv("data/pred_data.csv")

In [None]:
# NOTE(review): `to_csv` is referenced without being called, so this cell writes nothing
# and only displays the bound method. The saved output below looks like a DataFrame
# display, so the cell was presumably edited after its last run — confirm the intent.
test_data.to_csv

Unnamed: 0,user_id,sequence_movie_ids,sequence_ratings,sex,age_group,occupation
0,1,"1721,1270,1022,2340,1836,3408,1207,2804,260,72...",4553544543544455,0,0,10
1,1,"1836,3408,1207,2804,260,720,1193,919,608,2692,...",5445435444555454,0,0,10
2,1,"720,1193,919,608,2692,1961,2028,3105,938,1035,...",3544455545454543,0,0,10
3,1,"919,608,2692,1961,2028,3105,938,1035,1962,1028...",4445554545454354,0,0,10
4,1,"2692,1961,2028,3105,938,1035,1962,1028,2018,15...",4555454545435444,0,0,10
...,...,...,...,...,...,...
137795,999,"605,22,253,3259,1124,1183,271,1598,3174,3409,2...",2423323244334344,1,2,15
137796,999,"1589,3791,225,507,3173,3176,354,524,280,2447,1...",4443344443124343,1,2,15
137797,999,"2447,1515,724,266,450,2975,371,382,1027,1442,2...",3124343344343434,1,2,15
137798,999,"207,79,2875,2316,2165,361,2688,24,2264,1959,26...",4343133321323342,1,2,15


In [None]:
trainer.save_checkpoint('/content/saved_model', weights_only=True)

In [None]:
predict_dataset = LSTRDataset("data/pred_data.csv")

In [None]:
predict_dataset.__getitem__(1)

(1,
 tensor([1836, 3408, 1207, 2804,  260,  720, 1193,  919,  608, 2692, 1961, 2028]),
 tensor([3105,  938, 1035]),
 1962,
 tensor([5, 4, 4, 5, 4, 3, 5, 4, 4, 4, 5, 5]),
 tensor([5, 4, 5]),
 4,
 0,
 0,
 10)

In [None]:
# Batches of 8 samples from the prediction dataset, in file order.
# NOTE(review): this rebinds the name `data`, which earlier cells import as the alias
# for `torch.utils.data` and use in `class ...(data.Dataset)`; re-running those class
# cells after this one would fail — consider a different variable name.
data = torch.utils.data.DataLoader(
            predict_dataset,
            batch_size=8,
            shuffle=False,
            num_workers=os.cpu_count(),
        )

In [None]:
batch = next(iter(data))

In [None]:
# Inference only: eval mode turns off the dropout inside the transformer layers, and
# no_grad skips autograd bookkeeping, so the predictions are deterministic.
model.eval()
with torch.no_grad():
    sample_prediction = model(batch)

In [None]:
sample_prediction

(tensor([[4.3132],
         [4.0831],
         [4.2777],
         [4.1192],
         [4.2314],
         [4.0133],
         [4.1998],
         [3.9220]], grad_fn=<AddmmBackward0>),
 tensor([5., 4., 3., 4., 4., 4., 3., 5.]))

In [None]:
%load_ext tensorboard
%tensorboard --logdir /content/lightning_logs/version_4