In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
import copy
warnings.simplefilter(action='ignore')

In [2]:
# Load the prepared train and test order tables from the working directory.
# The cells below read the "order_id" and "product_id" columns of these frames.
data_train_final = pd.read_csv('data_train_final.csv')
data_test_final = pd.read_csv('data_test_final.csv')

In [4]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Embedding, Input, Dense
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization, Dense
import keras_nlp

In [5]:
# Prepare baskets
def prepare_baskets(data):
    """Collect the products of every order into one basket.

    Parameters
    ----------
    data : pandas.DataFrame
        Table that provides an ``order_id`` and a ``product_id`` column.

    Returns
    -------
    list of list
        One list of ``product_id`` values per distinct ``order_id``. Baskets
        follow pandas' default group order (ascending ``order_id``); inside a
        basket the products keep the row order they have in ``data``.
    """
    products_by_order = data.groupby("order_id")["product_id"]
    return [list(products) for _, products in products_by_order]

# Build one basket (list of product ids) per order for each data set.
train_baskets = prepare_baskets(data_train_final)
test_baskets = prepare_baskets(data_test_final)

# Hold out 25% of the training baskets for validation; the fixed random_state
# makes the split repeatable.
train_baskets, val_baskets = train_test_split(train_baskets, test_size=0.25, random_state=42)

# Size of the largest basket across train, validation and test; used later as
# the padded length of every context sequence.
max_len = max(len(basket) for basket in train_baskets + val_baskets + test_baskets)

In [6]:
# Embedding dimension of the item embedding layer.
D = 32
# Mini-batch size used for both fit() and evaluate().
batch_size = 256
# Upper bound on the number of training epochs.
max_epochs = 1000
# Learning rate passed to the Adam optimizer.
lr = 1e-4
# Number of distinct products in the full training table (before the
# train/validation split of the baskets).
# NOTE(review): the preprocessing and model cells treat product_id as a
# contiguous 1..max_items code (0 = padding, max_items + 1 = end-of-context
# token, target = product_id - 1) — confirm the ids were re-encoded that way.
max_items = len(set(data_train_final['product_id']))

In [7]:
def preprocess_baskets(baskets):
    """Expand baskets into (context, target) next-item training pairs.

    Every position of every basket yields one sample: the target is the item
    at that position and the context is the items that precede it, followed
    by the end-of-context token ``max_items + 1``.

    Relies on the notebook-level ``max_items`` and ``max_len``.

    Parameters
    ----------
    baskets : list of list of int
        Baskets of product ids.

    Returns
    -------
    tuple of numpy.ndarray
        ``(contexts, targets)``: ``contexts`` holds the sequences post-padded
        with 0 to length ``max_len``; ``targets`` holds the product ids
        shifted down by one.
    """
    end_token = max_items + 1
    prefixes = []
    targets = []

    for basket in baskets:
        # One sample per position: the target is the item itself ...
        targets.extend(basket)
        # ... and its context is everything before it plus the end token.
        prefixes.extend(basket[:pos] + [end_token] for pos in range(len(basket)))

    padded = pad_sequences(prefixes, padding='post', maxlen=max_len, value=0)
    return np.array(padded), np.array(targets) - 1

# Turn each basket split into padded context sequences and shifted targets
# (one sample per item of every basket).
train_context_input, train_target_input = preprocess_baskets(train_baskets)
val_context_input, val_target_input = preprocess_baskets(val_baskets)
test_context_input, test_target_input = preprocess_baskets(test_baskets)

In [13]:
# Spot-check one sample: preceding item ids, then the end-of-context token
# (max_items + 1), then zero padding up to max_len.
train_context_input[3]

array([55, 54, 12, 64,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0], dtype=int32)

In [14]:
# Target of the same sample: the product id to predict, shifted down by one.
train_target_input[3]

19

In [15]:
# Integer context sequences of fixed length max_len (0 marks padding).
input_context = layers.Input(shape=(max_len,), dtype=tf.int32, name="context_input")

# One D-dimensional vector per token id; max_items + 2 rows cover id 0 (padding),
# ids 1..max_items (products) and id max_items + 1 (end-of-context token).
alpha_embedding = layers.Embedding(input_dim=max_items + 2, output_dim=D, name="alpha_embedding")
context_embedding = alpha_embedding(input_context) 

class ZeroMaskEmbedding(layers.Layer):
    """Zero out the embedding vectors located at padding positions (token id 0)."""

    def call(self, embeddings, input_tokens):
        # 1.0 where a real token is present, 0.0 on padding; the extra trailing
        # axis lets the mask broadcast over the embedding dimension.
        is_real_token = tf.not_equal(input_tokens, 0)
        keep = tf.expand_dims(tf.cast(is_real_token, tf.float32), axis=-1)
        return embeddings * keep

context_embedding = ZeroMaskEmbedding()(context_embedding, input_context)

class SumLayer(layers.Layer):
    """Collapse axis 1 (the sequence positions) by summing the vectors."""

    def call(self, inputs):
        pooled = tf.reduce_sum(inputs, axis=1)
        return pooled

# Sum of the non-padding token embeddings of each context.
masked_embeddings = SumLayer()(context_embedding)

# Score every product from the pooled context vector; softmax turns the scores
# into a probability distribution over the max_items products.
output = layers.Dense(max_items, activation="softmax", name="output_layer", use_bias = False)(masked_embeddings)

model = Model(inputs=input_context, outputs=output)
# Targets are integer class indices, hence the sparse cross-entropy loss.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), loss="sparse_categorical_crossentropy")

# The patience has to stay below max_epochs for early stopping to be able to
# fire at all: with patience == max_epochs (the previous value, 1000) training
# always ran the full schedule even after val_loss had flattened out.
# Training now halts after 50 consecutive epochs without a val_loss improvement,
# and restore_best_weights still rolls back to the best epoch.
early_stopping = EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=True)

history = model.fit(
    train_context_input, train_target_input,
    validation_data=(val_context_input, val_target_input),
    batch_size=batch_size,
    epochs=max_epochs,
    callbacks=[early_stopping]
)

Epoch 1/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 4.1111 - val_loss: 4.0153
Epoch 2/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 4.0041 - val_loss: 3.9703
Epoch 3/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.9607 - val_loss: 3.9302
Epoch 4/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.9229 - val_loss: 3.8967
Epoch 5/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.8944 - val_loss: 3.8734
Epoch 6/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.8722 - val_loss: 3.8573
Epoch 7/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.8551 - val_loss: 3.8450
Epoch 8/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.8424 - val_loss: 3.8348
Epoch 9/1000
[1m807/807

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7272 - val_loss: 3.7363
Epoch 68/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4ms/step - loss: 3.7292 - val_loss: 3.7361
Epoch 69/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7286 - val_loss: 3.7358
Epoch 70/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7249 - val_loss: 3.7356
Epoch 71/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7287 - val_loss: 3.7354
Epoch 72/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7260 - val_loss: 3.7351
Epoch 73/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2ms/step - loss: 3.7283 - val_loss: 3.7349
Epoch 74/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7230 - val_loss: 3.7347
Epoch 75/1000
[1m807/807[0m 

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7216 - val_loss: 3.7287
Epoch 134/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7195 - val_loss: 3.7287
Epoch 135/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7215 - val_loss: 3.7286
Epoch 136/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7217 - val_loss: 3.7286
Epoch 137/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7201 - val_loss: 3.7285
Epoch 138/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7207 - val_loss: 3.7285
Epoch 139/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7204 - val_loss: 3.7284
Epoch 140/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7213 - val_loss: 3.7284
Epoch 141/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7164 - val_loss: 3.7263
Epoch 200/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7160 - val_loss: 3.7263
Epoch 201/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7187 - val_loss: 3.7262
Epoch 202/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7168 - val_loss: 3.7263
Epoch 203/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7164 - val_loss: 3.7262
Epoch 204/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7167 - val_loss: 3.7261
Epoch 205/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7179 - val_loss: 3.7262
Epoch 206/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2ms/step - loss: 3.7179 - val_loss: 3.7261
Epoch 207/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7134 - val_loss: 3.7249
Epoch 266/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7139 - val_loss: 3.7249
Epoch 267/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7150 - val_loss: 3.7249
Epoch 268/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7149 - val_loss: 3.7249
Epoch 269/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.7134 - val_loss: 3.7249
Epoch 270/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7198 - val_loss: 3.7248
Epoch 271/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7165 - val_loss: 3.7248
Epoch 272/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7143 - val_loss: 3.7248
Epoch 273/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7153 - val_loss: 3.7240
Epoch 332/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7137 - val_loss: 3.7239
Epoch 333/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7158 - val_loss: 3.7239
Epoch 334/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7150 - val_loss: 3.7239
Epoch 335/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7139 - val_loss: 3.7238
Epoch 336/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7156 - val_loss: 3.7238
Epoch 337/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7166 - val_loss: 3.7238
Epoch 338/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7161 - val_loss: 3.7238
Epoch 339/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7189 - val_loss: 3.7229
Epoch 398/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7155 - val_loss: 3.7229
Epoch 399/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 3ms/step - loss: 3.7152 - val_loss: 3.7229
Epoch 400/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7150 - val_loss: 3.7229
Epoch 401/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7119 - val_loss: 3.7228
Epoch 402/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7166 - val_loss: 3.7228
Epoch 403/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7166 - val_loss: 3.7228
Epoch 404/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7157 - val_loss: 3.7228
Epoch 405/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7139 - val_loss: 3.7221
Epoch 464/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7155 - val_loss: 3.7222
Epoch 465/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7149 - val_loss: 3.7221
Epoch 466/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7129 - val_loss: 3.7221
Epoch 467/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7174 - val_loss: 3.7221
Epoch 468/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7130 - val_loss: 3.7222
Epoch 469/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.7145 - val_loss: 3.7221
Epoch 470/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7142 - val_loss: 3.7221
Epoch 471/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7125 - val_loss: 3.7217
Epoch 530/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7119 - val_loss: 3.7217
Epoch 531/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7150 - val_loss: 3.7216
Epoch 532/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7171 - val_loss: 3.7216
Epoch 533/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7116 - val_loss: 3.7216
Epoch 534/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7145 - val_loss: 3.7216
Epoch 535/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7175 - val_loss: 3.7216
Epoch 536/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7140 - val_loss: 3.7216
Epoch 537/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7137 - val_loss: 3.7212
Epoch 596/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7120 - val_loss: 3.7212
Epoch 597/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7155 - val_loss: 3.7212
Epoch 598/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7113 - val_loss: 3.7212
Epoch 599/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7147 - val_loss: 3.7212
Epoch 600/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7136 - val_loss: 3.7213
Epoch 601/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7147 - val_loss: 3.7212
Epoch 602/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7124 - val_loss: 3.7212
Epoch 603/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7128 - val_loss: 3.7209
Epoch 662/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7133 - val_loss: 3.7209
Epoch 663/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7133 - val_loss: 3.7209
Epoch 664/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7108 - val_loss: 3.7210
Epoch 665/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7130 - val_loss: 3.7210
Epoch 666/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7105 - val_loss: 3.7209
Epoch 667/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7105 - val_loss: 3.7210
Epoch 668/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7127 - val_loss: 3.7211
Epoch 669/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7136 - val_loss: 3.7208
Epoch 728/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7137 - val_loss: 3.7209
Epoch 729/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.7155 - val_loss: 3.7208
Epoch 730/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7126 - val_loss: 3.7208
Epoch 731/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7143 - val_loss: 3.7208
Epoch 732/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7131 - val_loss: 3.7209
Epoch 733/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7136 - val_loss: 3.7208
Epoch 734/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7094 - val_loss: 3.7208
Epoch 735/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7120 - val_loss: 3.7206
Epoch 794/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7117 - val_loss: 3.7207
Epoch 795/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7115 - val_loss: 3.7207
Epoch 796/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2ms/step - loss: 3.7146 - val_loss: 3.7207
Epoch 797/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7119 - val_loss: 3.7206
Epoch 798/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7124 - val_loss: 3.7206
Epoch 799/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 2ms/step - loss: 3.7127 - val_loss: 3.7207
Epoch 800/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7109 - val_loss: 3.7206
Epoch 801/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7136 - val_loss: 3.7205
Epoch 860/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7125 - val_loss: 3.7205
Epoch 861/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7168 - val_loss: 3.7205
Epoch 862/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7148 - val_loss: 3.7205
Epoch 863/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7143 - val_loss: 3.7205
Epoch 864/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7103 - val_loss: 3.7205
Epoch 865/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7129 - val_loss: 3.7204
Epoch 866/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7099 - val_loss: 3.7205
Epoch 867/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7121 - val_loss: 3.7203
Epoch 926/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7105 - val_loss: 3.7204
Epoch 927/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7140 - val_loss: 3.7203
Epoch 928/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7096 - val_loss: 3.7204
Epoch 929/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7125 - val_loss: 3.7203
Epoch 930/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7146 - val_loss: 3.7203
Epoch 931/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7119 - val_loss: 3.7204
Epoch 932/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7144 - val_loss: 3.7204
Epoch 933/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7101 - val_loss: 3.7202
Epoch 992/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7117 - val_loss: 3.7202
Epoch 993/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.7115 - val_loss: 3.7202
Epoch 994/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7140 - val_loss: 3.7202
Epoch 995/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 4ms/step - loss: 3.7123 - val_loss: 3.7202
Epoch 996/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.7129 - val_loss: 3.7203
Epoch 997/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.7121 - val_loss: 3.7203
Epoch 998/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - loss: 3.7118 - val_loss: 3.7202
Epoch 999/1000
[1m807/

In [16]:
# Print the layer-by-layer architecture and parameter counts of the model.
model.summary()

In [17]:
# Evaluate on Test Data
# The model was compiled with a loss and no extra metrics, so the value
# reported here is the sparse categorical cross-entropy on the test samples.
test_loss = model.evaluate(test_context_input, test_target_input, batch_size=batch_size)
print(f"Test Loss: {test_loss}")

[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 3.7128
Test Loss: 3.715456485748291


In [23]:
alpha_embedding_layer = model.get_layer("alpha_embedding")
# Keep only the product rows of the embedding table: drop row 0 (padding id)
# and the last row (end-of-context token, id max_items + 1).
alpha_embedding_weights = alpha_embedding_layer.get_weights()[0][1:-1]

In [24]:
# Sanity check: one row per product, one column per embedding dimension.
alpha_embedding_weights.shape

(63, 32)

In [25]:
output_layer = model.get_layer("output_layer")
# The layer was built with use_bias=False, so its only weight is the kernel.
output_layer_weights = output_layer.get_weights()[0]

In [26]:
# Sanity check: one row per embedding dimension, one column per product.
output_layer_weights.shape

(32, 63)

In [27]:
# Product-by-product score table. The first term scores product j from the
# context embedding of product i; the second term is the same product taken in
# the transposed direction, so adding both gives a symmetric table.
sim_matrix = pd.DataFrame(
    alpha_embedding_weights @ output_layer_weights
    + output_layer_weights.T @ alpha_embedding_weights.T
)
sim_matrix

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,53,54,55,56,57,58,59,60,61,62
0,-14.117208,0.275404,0.644090,0.191577,0.514784,0.086332,0.253121,0.312386,0.481466,0.076147,...,0.018634,0.022815,0.290774,0.186797,0.455683,0.057822,0.130421,0.415062,-0.095735,0.153897
1,0.275404,-7.645995,0.738901,0.336778,0.171940,0.369585,0.079820,0.486054,0.342331,0.093342,...,0.038329,0.027772,-0.299224,-0.051832,0.138334,-0.289801,0.693116,0.365656,-0.351948,-0.596177
2,0.644090,0.738901,-10.244969,-0.014249,0.023768,0.226181,0.574340,0.594795,0.449021,0.471653,...,0.146462,-0.214105,-0.015649,0.252668,0.290005,-0.128685,0.368667,0.618434,-0.482750,0.008114
3,0.191577,0.336778,-0.014249,-10.121716,-0.152936,0.022510,0.202117,0.399451,0.252676,0.080621,...,0.578954,0.219032,-0.844487,-0.217740,0.313501,-0.063929,-0.182685,0.015575,-0.009273,-0.356413
4,0.514784,0.171940,0.023768,-0.152936,-0.570143,-0.011558,0.355079,0.298840,0.078561,-0.052688,...,0.037778,-0.649993,-0.244643,0.002512,-0.103974,0.219028,0.189867,0.111708,0.260134,0.174255
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
58,0.057822,-0.289801,-0.128685,-0.063929,0.219028,0.071868,0.163877,0.013608,0.105369,-0.101848,...,-3.174605,1.174975,0.248187,0.316648,-0.505899,-2.924871,0.317926,-0.489723,0.168180,0.473405
59,0.130421,0.693116,0.368667,-0.182685,0.189867,0.428426,0.572614,0.524153,0.243254,0.287745,...,0.222735,0.296259,0.598299,0.142787,-0.114281,0.317926,-10.812643,-0.247271,0.378463,0.181202
60,0.415062,0.365656,0.618434,0.015575,0.111708,-0.082365,-0.097845,0.416060,-0.422426,-0.289737,...,-0.035537,0.047743,-0.465540,-0.105159,-0.245749,-0.489723,-0.247271,-0.775956,0.079978,-0.766348
61,-0.095735,-0.351948,-0.482750,-0.009273,0.260134,-0.151748,-0.273854,-0.039525,-1.185645,0.268111,...,0.256391,-0.139810,0.277479,0.198102,0.219529,0.168180,0.378463,0.079978,-1.066577,0.314004


In [28]:
# Rank on a plain ndarray: depending on the installed numpy/pandas versions,
# np.argsort applied to a DataFrame may hand back a DataFrame, on which the
# positional [:, -5:] slicing below is not valid.
# argsort sorts ascending, so the last five columns of each row are the five
# highest scores; reversing them yields best-first order.
top_5_indices_desc = np.argsort(sim_matrix.to_numpy(), axis=1)[:, -5:][:, ::-1]

# Row position of a product -> positions of its five highest-scoring products.
top_5_dict = {
    i: list(top_5_indices_desc[i])
    for i in range(top_5_indices_desc.shape[0])
}

In [29]:
# Lookup table: product_id -> product_name.
products = pd.read_csv('products.csv')
# Pair the two columns positionally in a single pass. The former per-row label
# lookup (products[col][i]) was slow and only worked when the frame carried
# a default 0..n-1 index. As before, a repeated product_id keeps its last name.
products_dict = dict(zip(products['product_id'], products['product_name']))

In [30]:
# Catalogue ids of the modelled products, listed by position.
# NOTE(review): position k is presumably the product whose row/column k appears
# in sim_matrix — confirm where this ordering was generated.
product_ids = [
    21903, 30391, 46667, 13176, 21616,  8518, 22935,  5876, 48679,
    24838, 31717, 47209, 26209, 34969, 27966, 37646, 44632, 16797,
    39275,  5077, 10749, 49235, 21137, 28204, 21938, 46979, 47626,
    44359, 34126, 28985, 24852, 41950, 30489,  9076, 24964, 45007,
    42265, 49683, 47766, 39877, 19057, 40706,  5450, 43961, 39928,
    22825, 12341, 17794,  4605, 22035, 27845, 27104, 26604,  8277,
     4920, 25890, 31506, 35951, 45066, 24184, 19660, 27086, 43352,
]

# Product names in the same order as product_ids.
all_products = [products_dict[product_id] for product_id in product_ids]

In [31]:
# Translate positions into names: every product name maps to the names of its
# five highest-scoring products, best first.
top_5_dict_items = {
    all_products[position]: [all_products[neighbour] for neighbour in neighbours]
    for position, neighbours in top_5_dict.items()
}

In [32]:
# Display the product name -> top-5 product names mapping.
top_5_dict_items

{'Organic Baby Spinach': ['Organic Ginger Root',
  'Organic Zucchini',
  'Organic Avocado',
  'Organic Grape Tomatoes',
  'Organic Baby Arugula'],
 'Organic Cucumber': ['Organic Small Bunch Celery',
  'Organic Hass Avocado',
  'Organic Ginger Root',
  'Red Peppers',
  'Organic Large Extra Fancy Fuji Apple'],
 'Organic Ginger Root': ['Organic Garlic',
  'Organic Cilantro',
  'Organic Cucumber',
  'Limes',
  'Organic Italian Parsley Bunch'],
 'Bag of Organic Bananas': ['Organic Hass Avocado',
  'Hass Avocados',
  'Organic Large Extra Fancy Fuji Apple',
  'Organic Raspberries',
  "Organic D'Anjou Pears"],
 'Organic Baby Arugula': ['Organic Grape Tomatoes',
  'Organic Baby Spinach',
  'Large Lemon',
  'Organic Italian Parsley Bunch',
  'Hass Avocados'],
 'Organic Red Onion': ['Organic Cilantro',
  'Organic Red Bell Pepper',
  'Limes',
  'Small Hass Avocado',
  'Green Bell Pepper'],
 'Organic Yellow Onion': ['Organic Garlic',
  'Organic Garnet Sweet Potato (Yam)',
  'Organic Italian Parsley