In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
import copy
warnings.simplefilter(action='ignore')

In [2]:
data_train_final = pd.read_csv('data_train_final.csv')
data_test_final = pd.read_csv('data_test_final.csv')

In [4]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Embedding, Input, Dense
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [5]:
# Prepare baskets
def prepare_baskets(data):
    return data.groupby("order_id")["product_id"].apply(list).tolist()

train_baskets = prepare_baskets(data_train_final)
test_baskets = prepare_baskets(data_test_final)

train_baskets, val_baskets = train_test_split(train_baskets, test_size=0.25, random_state=42)

max_len = max(len(basket) for basket in train_baskets + val_baskets + test_baskets)

In [6]:
D = 32
batch_size = 256
max_epochs = 1000
lr = 1e-4
max_items = len(set(data_train_final['product_id']))

In [7]:
def preprocess_baskets(baskets):
    context_inputs = []
    target_inputs = []
    
    for basket in baskets:
        for idx, elt in enumerate(basket):
            target_inputs.append(elt)
            context_inputs.append(basket[:idx] + basket[(idx+1):])
    
    context_inputs = pad_sequences(context_inputs, padding='post', maxlen = max_len - 1, value=0)
    return np.array(context_inputs), np.array(target_inputs) - 1

train_context_input, train_target_input = preprocess_baskets(train_baskets)
val_context_input, val_target_input = preprocess_baskets(val_baskets)
test_context_input, test_target_input = preprocess_baskets(test_baskets)

In [8]:
train_context_input[2]

array([55, 54, 20,  4,  1, 29, 21, 11,  8, 23,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0], dtype=int32)

In [9]:
train_target_input[2]

11

In [10]:
input_context = layers.Input(shape=(max_len - 1,), dtype=tf.int32, name="context_input")

alpha_embedding = layers.Embedding(input_dim=max_items + 1, output_dim=D, name="alpha_embedding")
context_embedding = alpha_embedding(input_context)  # Shape: (batch_size, max_len, embedding_dim)

class ZeroMaskEmbedding(layers.Layer):
    def call(self, embeddings, input_tokens):
        mask = tf.cast(tf.not_equal(input_tokens, 0), tf.float32) 
        mask = tf.expand_dims(mask, axis=-1) 
        return embeddings * mask 

context_embedding = ZeroMaskEmbedding()(context_embedding, input_context)

class SumLayer(layers.Layer):
    def call(self, inputs):
        return tf.reduce_sum(inputs, axis=1) 

summed_context = SumLayer()(context_embedding)

output = layers.Dense(max_items, activation="softmax", name="output_layer", use_bias = False)(summed_context)

model = Model(inputs=input_context, outputs=output)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), loss="sparse_categorical_crossentropy")

early_stopping = EarlyStopping(monitor='val_loss', patience=1000, restore_best_weights=True)

history = model.fit(
    train_context_input, train_target_input,        
    validation_data=(val_context_input, val_target_input),
    batch_size=batch_size,
    epochs=max_epochs,
    callbacks=[early_stopping]
)

Epoch 1/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 4.0903 - val_loss: 3.9669
Epoch 2/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.9537 - val_loss: 3.9063
Epoch 3/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.8948 - val_loss: 3.8501
Epoch 4/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.8377 - val_loss: 3.7998
Epoch 5/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7917 - val_loss: 3.7582
Epoch 6/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7541 - val_loss: 3.7243
Epoch 7/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.7210 - val_loss: 3.6966
Epoch 8/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.6932 - val_loss: 3.6735
Epoch 9/1000
[1m807/807

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4909 - val_loss: 3.4979
Epoch 68/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4881 - val_loss: 3.4978
Epoch 69/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4871 - val_loss: 3.4976
Epoch 70/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4865 - val_loss: 3.4972
Epoch 71/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4890 - val_loss: 3.4970
Epoch 72/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4870 - val_loss: 3.4967
Epoch 73/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4878 - val_loss: 3.4965
Epoch 74/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4851 - val_loss: 3.4963
Epoch 75/1000
[1m807/807[0m 

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4789 - val_loss: 3.4899
Epoch 134/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4782 - val_loss: 3.4900
Epoch 135/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4812 - val_loss: 3.4898
Epoch 136/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4777 - val_loss: 3.4897
Epoch 137/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4778 - val_loss: 3.4896
Epoch 138/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4792 - val_loss: 3.4896
Epoch 139/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4784 - val_loss: 3.4895
Epoch 140/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4778 - val_loss: 3.4896
Epoch 141/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4764 - val_loss: 3.4858
Epoch 200/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4725 - val_loss: 3.4858
Epoch 201/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4745 - val_loss: 3.4856
Epoch 202/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4728 - val_loss: 3.4857
Epoch 203/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4744 - val_loss: 3.4855
Epoch 204/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4716 - val_loss: 3.4855
Epoch 205/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4729 - val_loss: 3.4855
Epoch 206/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4732 - val_loss: 3.4854
Epoch 207/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4704 - val_loss: 3.4824
Epoch 266/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4719 - val_loss: 3.4822
Epoch 267/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4751 - val_loss: 3.4823
Epoch 268/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4689 - val_loss: 3.4822
Epoch 269/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4685 - val_loss: 3.4822
Epoch 270/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4763 - val_loss: 3.4821
Epoch 271/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4688 - val_loss: 3.4822
Epoch 272/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4731 - val_loss: 3.4821
Epoch 273/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4695 - val_loss: 3.4798
Epoch 332/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4728 - val_loss: 3.4798
Epoch 333/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4690 - val_loss: 3.4799
Epoch 334/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4713 - val_loss: 3.4797
Epoch 335/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4667 - val_loss: 3.4797
Epoch 336/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4688 - val_loss: 3.4797
Epoch 337/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4726 - val_loss: 3.4796
Epoch 338/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4670 - val_loss: 3.4798
Epoch 339/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4707 - val_loss: 3.4784
Epoch 398/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4662 - val_loss: 3.4784
Epoch 399/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4718 - val_loss: 3.4785
Epoch 400/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4683 - val_loss: 3.4785
Epoch 401/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4674 - val_loss: 3.4784
Epoch 402/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4715 - val_loss: 3.4783
Epoch 403/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4714 - val_loss: 3.4782
Epoch 404/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4654 - val_loss: 3.4784
Epoch 405/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4706 - val_loss: 3.4775
Epoch 464/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4682 - val_loss: 3.4775
Epoch 465/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4687 - val_loss: 3.4775
Epoch 466/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4693 - val_loss: 3.4775
Epoch 467/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4706 - val_loss: 3.4775
Epoch 468/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4657 - val_loss: 3.4776
Epoch 469/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4708 - val_loss: 3.4774
Epoch 470/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4707 - val_loss: 3.4774
Epoch 471/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4697 - val_loss: 3.4770
Epoch 530/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4689 - val_loss: 3.4770
Epoch 531/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4673 - val_loss: 3.4771
Epoch 532/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4673 - val_loss: 3.4771
Epoch 533/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 2ms/step - loss: 3.4704 - val_loss: 3.4772
Epoch 534/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4678 - val_loss: 3.4771
Epoch 535/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4701 - val_loss: 3.4770
Epoch 536/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4684 - val_loss: 3.4771
Epoch 537/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4655 - val_loss: 3.4768
Epoch 596/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4639 - val_loss: 3.4768
Epoch 597/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4698 - val_loss: 3.4768
Epoch 598/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4705 - val_loss: 3.4768
Epoch 599/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4648 - val_loss: 3.4768
Epoch 600/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4669 - val_loss: 3.4768
Epoch 601/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4653 - val_loss: 3.4767
Epoch 602/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4664 - val_loss: 3.4767
Epoch 603/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4674 - val_loss: 3.4765
Epoch 662/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4698 - val_loss: 3.4766
Epoch 663/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4678 - val_loss: 3.4766
Epoch 664/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4712 - val_loss: 3.4767
Epoch 665/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4621 - val_loss: 3.4766
Epoch 666/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4699 - val_loss: 3.4767
Epoch 667/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4673 - val_loss: 3.4766
Epoch 668/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4608 - val_loss: 3.4766
Epoch 669/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4681 - val_loss: 3.4765
Epoch 728/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4702 - val_loss: 3.4765
Epoch 729/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4668 - val_loss: 3.4766
Epoch 730/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4668 - val_loss: 3.4765
Epoch 731/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4675 - val_loss: 3.4764
Epoch 732/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4640 - val_loss: 3.4765
Epoch 733/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4688 - val_loss: 3.4765
Epoch 734/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4657 - val_loss: 3.4765
Epoch 735/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4652 - val_loss: 3.4763
Epoch 794/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4653 - val_loss: 3.4764
Epoch 795/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4669 - val_loss: 3.4763
Epoch 796/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4681 - val_loss: 3.4765
Epoch 797/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.4683 - val_loss: 3.4764
Epoch 798/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.4666 - val_loss: 3.4763
Epoch 799/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.4682 - val_loss: 3.4764
Epoch 800/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - loss: 3.4649 - val_loss: 3.4764
Epoch 801/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.4675 - val_loss: 3.4763
Epoch 860/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.4688 - val_loss: 3.4764
Epoch 861/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4688 - val_loss: 3.4764
Epoch 862/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.4667 - val_loss: 3.4763
Epoch 863/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4641 - val_loss: 3.4764
Epoch 864/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.4665 - val_loss: 3.4763
Epoch 865/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4698 - val_loss: 3.4764
Epoch 866/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.4655 - val_loss: 3.4762
Epoch 867/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4693 - val_loss: 3.4762
Epoch 926/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.4696 - val_loss: 3.4763
Epoch 927/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4688 - val_loss: 3.4762
Epoch 928/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.4716 - val_loss: 3.4763
Epoch 929/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4690 - val_loss: 3.4764
Epoch 930/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.4667 - val_loss: 3.4763
Epoch 931/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4686 - val_loss: 3.4763
Epoch 932/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4671 - val_loss: 3.4762
Epoch 933/1000
[1m807/

[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.4678 - val_loss: 3.4762
Epoch 992/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.4709 - val_loss: 3.4762
Epoch 993/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.4691 - val_loss: 3.4763
Epoch 994/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4637 - val_loss: 3.4763
Epoch 995/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 3ms/step - loss: 3.4665 - val_loss: 3.4762
Epoch 996/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.4656 - val_loss: 3.4763
Epoch 997/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - loss: 3.4683 - val_loss: 3.4762
Epoch 998/1000
[1m807/807[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - loss: 3.4606 - val_loss: 3.4762
Epoch 999/1000
[1m807/

In [11]:
model.summary()

In [12]:
# Evaluate on Test Data
test_loss = model.evaluate(test_context_input, test_target_input, batch_size=batch_size)
print(f"Test Loss: {test_loss}")

[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 3.4709
Test Loss: 3.475806713104248


In [20]:
alpha_embedding_layer = model.get_layer("alpha_embedding")
alpha_embedding_weights = alpha_embedding_layer.get_weights()[0][1:]

In [21]:
alpha_embedding_weights.shape

(63, 32)

In [22]:
output_layer = model.get_layer("output_layer")
output_layer_weights = output_layer.get_weights()[0]

In [23]:
output_layer_weights.shape

(32, 63)

In [32]:
sim_matrix = pd.DataFrame(np.matmul(alpha_embedding_weights, output_layer_weights) + \
    np.matmul(output_layer_weights.T, alpha_embedding_weights.T))
sim_matrix

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,53,54,55,56,57,58,59,60,61,62
0,-15.330813,0.469850,0.820668,0.795476,0.582033,0.186281,0.299645,0.493317,0.494980,0.227307,...,0.110086,0.070254,0.306740,0.031656,0.618798,0.076615,-0.008764,0.562992,-0.044297,0.175287
1,0.469850,-7.686265,0.896951,0.836317,0.335659,0.544648,0.010879,0.673423,0.411465,0.153536,...,0.117116,0.002903,-0.053115,-0.065409,0.163817,-0.209091,0.703114,0.340024,-0.268221,-0.586682
2,0.820668,0.896951,-10.672608,0.262622,-0.036154,0.258761,0.519897,0.627843,0.291334,0.287182,...,0.066998,-0.359007,-0.226515,-0.043277,0.240207,-0.199370,0.125615,0.422357,-0.436632,-0.197823
3,0.795476,0.836317,0.262622,-10.256343,0.169370,0.380481,0.502093,0.911784,0.567980,0.541003,...,1.018122,0.504931,-0.333991,-0.023770,0.804823,0.294413,0.178293,0.323833,0.296426,-0.044782
4,0.582033,0.335659,-0.036154,0.169370,-0.616607,-0.047217,0.242716,0.362202,-0.141816,-0.368387,...,0.029793,-0.948857,-0.370922,-0.108187,-0.259749,0.050430,-0.178378,0.020744,-0.075644,0.158275
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
58,0.076615,-0.209091,-0.199370,0.294413,0.050430,0.148731,0.157252,0.068566,-0.020439,-0.004240,...,-3.709913,1.208980,-0.060583,-0.061179,-0.516071,-3.471307,0.124755,-0.336801,0.256702,0.293350
59,-0.008764,0.703114,0.125615,0.178293,-0.178378,0.348261,0.393002,0.534004,-0.094640,0.098490,...,-0.065361,0.009164,0.222825,-0.268717,-0.405140,0.124755,-9.945126,-0.576504,0.307162,-0.157741
60,0.562992,0.340024,0.422357,0.323833,0.020744,-0.157326,-0.281048,0.571727,-0.533834,-0.265351,...,0.032587,-0.132143,-0.422525,-0.172527,-0.128632,-0.336801,-0.576504,-0.598503,-0.140465,-0.518257
61,-0.044297,-0.268221,-0.436632,0.296426,-0.075644,-0.190654,-0.296344,-0.084403,-1.348392,0.083726,...,0.161154,-0.303726,0.077285,0.036804,0.084862,0.256702,0.307162,-0.140465,-0.765523,0.151635


In [47]:
top_5_indices_desc = np.argsort(sim_matrix, axis=1)[:, -5:][:, ::-1]

top_5_dict = {}
for i in range(top_5_indices_desc.shape[0]):
    top_5_dict[i] = list(top_5_indices_desc[i])

In [53]:
products = pd.read_csv('products.csv')
products_dict = {}
for i in range(products.shape[0]):
    products_dict[products['product_id'][i]] = products['product_name'][i]

In [55]:
product_ids = [21903, 30391, 46667, 13176, 21616,  8518, 22935,  5876, 48679,
       24838, 31717, 47209, 26209, 34969, 27966, 37646, 44632, 16797,
       39275,  5077, 10749, 49235, 21137, 28204, 21938, 46979, 47626,
       44359, 34126, 28985, 24852, 41950, 30489,  9076, 24964, 45007,
       42265, 49683, 47766, 39877, 19057, 40706,  5450, 43961, 39928,
       22825, 12341, 17794,  4605, 22035, 27845, 27104, 26604,  8277,
        4920, 25890, 31506, 35951, 45066, 24184, 19660, 27086, 43352]

all_products = []
for i in product_ids:
    all_products.append(products_dict[i])

In [62]:
top_5_dict_items = {}

for k, v in top_5_dict.items():
    key = all_products[k]
    value = [all_products[val] for val in v]
    
    top_5_dict_items[key] = value

In [63]:
top_5_dict_items

{'Organic Baby Spinach': ['Organic Avocado',
  'Banana',
  'Organic Ginger Root',
  'Bag of Organic Bananas',
  'Organic Zucchini'],
 'Organic Cucumber': ['Organic Small Bunch Celery',
  'Organic Hass Avocado',
  'Organic Ginger Root',
  'Bag of Organic Bananas',
  'Organic Grape Tomatoes'],
 'Organic Ginger Root': ['Organic Garlic',
  'Organic Cilantro',
  'Organic Cucumber',
  'Organic Baby Spinach',
  'Michigan Organic Kale'],
 'Bag of Organic Bananas': ['Organic Hass Avocado',
  'Organic Raspberries',
  'Hass Avocados',
  'Organic Large Extra Fancy Fuji Apple',
  'Organic Strawberries'],
 'Organic Baby Arugula': ['Organic Grape Tomatoes',
  'Organic Baby Spinach',
  'Organic Avocado',
  'Large Lemon',
  'Organic Hass Avocado'],
 'Organic Red Onion': ['Organic Cilantro',
  'Small Hass Avocado',
  'Limes',
  'Organic Garlic',
  'Green Bell Pepper'],
 'Organic Yellow Onion': ['Organic Garlic',
  'Organic Small Bunch Celery',
  'Organic Garnet Sweet Potato (Yam)',
  'Apple Honeycrisp O