In [2]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Embedding, Input, Dense
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization, Dense
import keras_nlp

In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
import copy
warnings.simplefilter(action='ignore')

In [4]:
# Load the MovieLens 100K dataset
# Assuming the file is 'u.data' in the same directory
# Columns: user_id, movie_id, rating, timestamp
data = pd.read_csv('u.data', sep='\t', names=['user_id', 'movie_id', 'rating', 'timestamp'])

# Step 0: only select top 50 movies
valid_movies = np.array(data['movie_id'].value_counts().index[:50])
data = data[data['movie_id'].isin(valid_movies)]

# Step 1: Remove users s.t. count >= 2 * unique timestamp
temp = data.groupby('user_id')['timestamp'].agg(['count', 'nunique']).reset_index()
valid_users = temp[temp['count']/temp['nunique'] < 2]['user_id']
data = data[data['user_id'].isin(valid_users)]

# Step 2: Deduplicate and recode
data = data.groupby(['user_id', 'timestamp'], group_keys=False).apply(lambda group: group.sample(n=1, random_state=42))
data = data.sort_values(['user_id', 'timestamp']).reset_index(drop = True)
data['movie_id'] = pd.factorize(data['movie_id'])[0] + 1

num_users = data['user_id'].nunique()
num_movies = data['movie_id'].nunique()
num_ratings = len(data)

sparsity = 1 - (num_ratings / (num_users * num_movies))

print(f"Filtered Dataset: {num_users} users, {num_movies} movies")
print(f"Sparsity: {sparsity:.2%}")

Filtered Dataset: 902 users, 50 movies
Sparsity: 69.50%


In [5]:
data

Unnamed: 0,user_id,movie_id,rating,timestamp
0,1,1,5,874965478
1,1,2,5,874965706
2,1,3,5,874965739
3,1,4,5,874965758
4,1,5,5,874965954
...,...,...,...,...
13752,943,23,4,888639407
13753,943,13,5,888639427
13754,943,8,4,888692413
13755,943,41,4,888692699


In [6]:
np.random.seed(42)

all_users = data['user_id'].unique()
shuffled_indices = np.random.permutation(len(all_users))

midpoint = 3 * len(all_users) // 4
indices_1 = shuffled_indices[:midpoint]
indices_2 = shuffled_indices[midpoint:]

train_users = shuffled_indices[indices_1]
test_users = shuffled_indices[indices_2]

In [7]:
data_train = data[data['user_id'].isin(train_users)].reset_index(drop = True)
data_test = data[data['user_id'].isin(test_users)].reset_index(drop = True)

In [8]:
# Prepare baskets
def prepare_baskets(data):
    return data.groupby("user_id")["movie_id"].apply(list).tolist()

train_baskets = prepare_baskets(data_train)
test_baskets = prepare_baskets(data_test)

train_baskets, val_baskets = train_test_split(train_baskets, test_size=0.25, random_state=42)

max_len = max(len(basket) for basket in train_baskets + val_baskets + test_baskets)

In [9]:
max(len(basket) for basket in test_baskets)

43

In [10]:
D = 32
batch_size = 256
max_epochs = 2000
lr = 1e-4
max_items = len(set(data_test['movie_id']))

In [11]:
len(set(data_test['movie_id']))

50

In [12]:
def preprocess_baskets(baskets):
    context_inputs = []
    target_inputs = []
    masked_idxs = []

    for basket in baskets:
        for idx, elt in enumerate(basket):
            target_inputs.append(elt)
            context_inputs.append(basket[:idx] + [max_items + 1] + basket[(idx+1):])
            masked_idxs.append(idx)

    context_inputs = pad_sequences(context_inputs, padding='post', maxlen = max_len, value=0)
    return np.array(context_inputs), np.array(target_inputs) - 1, np.array(masked_idxs)

train_context_input, train_target_input, train_masked_idxs = preprocess_baskets(train_baskets)
val_context_input, val_target_input, val_masked_idxs = preprocess_baskets(val_baskets)
test_context_input, test_target_input, test_masked_idxs = preprocess_baskets(test_baskets)

In [13]:
test_context_input[0]

array([51, 37,  2,  5, 49, 21, 18,  6, 28,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0], dtype=int32)

In [14]:
input_context = layers.Input(shape=(max_len,), dtype=tf.int32, name="context_input")
masked_idx_input = layers.Input(shape=(1,), dtype=tf.int32, name="masked_idx_input")

alpha_embedding = layers.Embedding(input_dim=max_items + 2, output_dim=D, name="alpha_embedding")
context_embedding = alpha_embedding(input_context) 

class ZeroMaskEmbedding(layers.Layer):
    def call(self, embeddings, input_tokens):
        mask = tf.cast(tf.not_equal(input_tokens, 0), tf.float32) 
        mask = tf.expand_dims(mask, axis=-1) 
        return embeddings * mask 

context_embedding = ZeroMaskEmbedding()(context_embedding, input_context)

class MaskLayer(layers.Layer):
    def call(self, input_context, position):
        return position * tf.expand_dims(tf.cast(tf.not_equal(input_context, 0), tf.float32), axis = -1)

position = keras_nlp.layers.PositionEmbedding(sequence_length=max_len)(context_embedding)
masked_position = MaskLayer()(input_context, position)
context_embedding = context_embedding + masked_position

class CreateAttentionMask(layers.Layer):
    def call(self, inputs):
        input_context = inputs
        temp = tf.cast(tf.not_equal(input_context, 0), dtype=tf.float32)
        return tf.expand_dims(tf.expand_dims(temp, axis=1), axis=1)

attention_mask = CreateAttentionMask()(input_context)

attention_layer_1 = MultiHeadAttention(num_heads=2, key_dim=16, name="multi_head_attention_1")
attn_output_1 = attention_layer_1(
    query=context_embedding,
    value=context_embedding,
    key=context_embedding,
    attention_mask=attention_mask
)

attn_output_1 = context_embedding + attn_output_1

attention_layer_2 = MultiHeadAttention(num_heads=2, key_dim=16, name="multi_head_attention_2")
attn_output_2 = attention_layer_2(
    query=attn_output_1,
    value=attn_output_1,
    key=attn_output_1,
    attention_mask=attention_mask
)

context_embedding = attn_output_1 + attn_output_2

class GatherLayer(layers.Layer):
    def call(self, inputs):
        context_embedding, masked_idx_input = inputs
        return tf.gather(context_embedding, indices=tf.squeeze(masked_idx_input, axis=-1), batch_dims=1)

masked_embeddings = GatherLayer()([context_embedding, masked_idx_input])

output = layers.Dense(max_items, activation="softmax", name="output_layer", use_bias = False)(masked_embeddings)

model = Model(inputs=[input_context, masked_idx_input], outputs=output)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), loss="sparse_categorical_crossentropy")

early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

history = model.fit(
    [train_context_input, train_masked_idxs], train_target_input,
    validation_data=([val_context_input, val_masked_idxs], val_target_input),
    batch_size=batch_size,
    epochs=max_epochs,
    callbacks=[early_stopping]
)

Epoch 1/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 27ms/step - loss: 3.9286 - val_loss: 3.9236
Epoch 2/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.9226 - val_loss: 3.9212
Epoch 3/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.9190 - val_loss: 3.9190
Epoch 4/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.9166 - val_loss: 3.9168
Epoch 5/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.9167 - val_loss: 3.9146
Epoch 6/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.9109 - val_loss: 3.9125
Epoch 7/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.9065 - val_loss: 3.9104
Epoch 8/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.9075 - val_loss: 3.9081
Epoch 9/2000
[1m29/29[0m [32m

[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.5786 - val_loss: 3.6067
Epoch 69/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.5636 - val_loss: 3.6041
Epoch 70/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.5608 - val_loss: 3.6024
Epoch 71/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 21ms/step - loss: 3.5649 - val_loss: 3.6000
Epoch 72/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.5690 - val_loss: 3.5984
Epoch 73/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.5522 - val_loss: 3.5966
Epoch 74/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.5494 - val_loss: 3.5949
Epoch 75/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.5675 - val_loss: 3.5933
Epoch 76/2000
[1m29/29[0m [32m━━━━━

Epoch 135/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4799 - val_loss: 3.5353
Epoch 136/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4903 - val_loss: 3.5360
Epoch 137/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4702 - val_loss: 3.5341
Epoch 138/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4731 - val_loss: 3.5340
Epoch 139/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4579 - val_loss: 3.5332
Epoch 140/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4589 - val_loss: 3.5335
Epoch 141/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4765 - val_loss: 3.5327
Epoch 142/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4705 - val_loss: 3.5327
Epoch 143/2000


[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4153 - val_loss: 3.5125
Epoch 202/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4179 - val_loss: 3.5124
Epoch 203/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4216 - val_loss: 3.5113
Epoch 204/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4162 - val_loss: 3.5109
Epoch 205/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4458 - val_loss: 3.5113
Epoch 206/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4100 - val_loss: 3.5107
Epoch 207/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4149 - val_loss: 3.5096
Epoch 208/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.4172 - val_loss: 3.5109
Epoch 209/2000
[1m29/29[0m [

[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 21ms/step - loss: 3.3858 - val_loss: 3.4976
Epoch 268/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3927 - val_loss: 3.4994
Epoch 269/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3775 - val_loss: 3.4984
Epoch 270/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3779 - val_loss: 3.4980
Epoch 271/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3887 - val_loss: 3.4983
Epoch 272/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3840 - val_loss: 3.4973
Epoch 273/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3919 - val_loss: 3.4971
Epoch 274/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3826 - val_loss: 3.4967
Epoch 275/2000
[1m29/29[0m [

[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3436 - val_loss: 3.4911
Epoch 334/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3574 - val_loss: 3.4913
Epoch 335/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3552 - val_loss: 3.4918
Epoch 336/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3473 - val_loss: 3.4915
Epoch 337/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3341 - val_loss: 3.4914
Epoch 338/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3456 - val_loss: 3.4911
Epoch 339/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3430 - val_loss: 3.4905
Epoch 340/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3623 - val_loss: 3.4912
Epoch 341/2000
[1m29/29[0m [

[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3250 - val_loss: 3.4871
Epoch 400/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3253 - val_loss: 3.4866
Epoch 401/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3246 - val_loss: 3.4869
Epoch 402/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3296 - val_loss: 3.4862
Epoch 403/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3253 - val_loss: 3.4875
Epoch 404/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3252 - val_loss: 3.4869
Epoch 405/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3364 - val_loss: 3.4863
Epoch 406/2000
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - loss: 3.3062 - val_loss: 3.4870
Epoch 407/2000
[1m29/29[0m [

In [15]:
model.summary()

In [16]:
# Evaluate on Test Data
test_loss = model.evaluate([test_context_input, test_masked_idxs], test_target_input, batch_size=batch_size)
print(f"Test Loss: {test_loss}")

[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - loss: 3.4869
Test Loss: 3.482757091522217
