In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import torch

from util import *

plt.style.use("ggplot")

In [2]:
df = pd.read_json("./data/preprocessed/preprocessed.json", orient="index")
with open("./data/preprocessed/param.json", "r") as f:
    params = json.load(f)
    print("params:", params)

user_n, item_n = params["user_n"], params["item_n"]

params: {'user_n': 1096, 'item_n': 3664}


In [3]:
df.head()

Unnamed: 0,valid_train_purchased_items,valid_eval_purchased_items,test_train_purchased_items,test_eval_purchased_items,valid_train_recommended_items,valid_eval_recommended_items,test_train_recommended_items,test_eval_recommended_items
0,"[3584, 2561, 3586, 3588, 3589, 3076, 522, 2061...","[2561, 516, 1157, 1159, 2061, 1038, 1171, 791,...","[3584, 2561, 3586, 3588, 3589, 3076, 516, 522,...","[2561, 3588, 3589, 2572, 2061, 2960, 277, 2972...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2049, 1, 2, 4, 5, 9, 11, 12, 2059, 13, 18, 20...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[3, 4, 8, 10, 11, 16, 19, 32, 33, 34, 37, 42, ..."
1,"[0, 1, 3076, 3591, 2055, 1543, 2571, 2575, 206...","[1029, 1931, 2833, 2836, 2342, 1576, 303, 1969...","[0, 1, 3076, 1029, 3591, 2055, 1543, 2571, 257...","[130, 2055, 392, 137, 2824, 1931, 268, 1296, 2...","[0, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 16, 19, 2...","[2048, 2049, 2, 5, 6, 2055, 9, 2059, 11, 13, 1...","[0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16...","[2048, 1, 2049, 3, 5, 6, 7, 8, 2058, 2059, 19,..."
2,"[5, 2055, 2057, 2058, 2060, 2065, 2067, 2068, ...","[5, 2577, 1554, 531, 2068, 2075, 1564, 2082, 1...","[5, 2055, 2057, 2058, 2060, 2065, 2067, 2068, ...","[517, 2058, 1035, 2065, 2577, 1043, 2068, 531,...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2, 5, 12, 13, 26, 27, 35, 42, 44, 47, 57, 58,...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2, 3, 7, 10, 11, 19, 26, 29, 30, 33, 34, 41, ..."
3,"[1536, 3073, 4, 2570, 2571, 2060, 524, 1551, 3...","[2304, 2695, 906, 2060, 2957, 2446, 527, 1295,...","[1536, 3073, 4, 2570, 2571, 2060, 524, 1551, 5...","[906, 2446, 21, 802, 1061, 936, 937, 45, 2222,...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 16...","[2049, 1, 2, 5, 2054, 2055, 11, 12, 2064, 2067...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 16...","[2049, 3, 6, 2055, 2054, 2060, 2068, 2071, 25,..."
4,"[2571, 2069, 1564, 1054, 34, 2101, 2623, 2626,...","[1923, 3076, 1159, 2187, 1296, 1044, 1818, 924...","[3076, 2571, 1044, 2069, 1564, 1054, 34, 2101,...","[9, 143, 2071, 1176, 2331, 2086, 3116, 2222, 8...","[0, 6, 13, 26, 27, 29, 30, 31, 33, 34, 35, 38,...","[2049, 1029, 5, 1033, 1036, 2067, 2068, 2069, ...","[0, 5, 6, 13, 26, 27, 29, 30, 31, 33, 34, 35, ...","[1, 2049, 3, 1026, 6, 2055, 3078, 11, 2059, 30..."


In [10]:
class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.positional_embedding = torch.nn.Embedding(
            num_embeddings=max_len, embedding_dim=d_model
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        i = torch.LongTensor(list(range(x.size(-2))))
        x = x + self.positional_embedding.forward(i)
        return x

In [43]:
class Model(torch.nn.Module):
    def __init__(self, d_model: int, user_n: int, item_n: int):
        super().__init__()
        self.user_embedding = torch.nn.Embedding(
            num_embeddings=user_n, embedding_dim=d_model
        )
        self.item_embedding = torch.nn.Embedding(
            num_embeddings=item_n + 1, embedding_dim=d_model
        )
        self.positional_encoding = PositionalEncoding(d_model=d_model)
        self.transformer_encoder = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(
                d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True
            ),
            num_layers=2,
        )

    def forward(
        self,
        u: torch.LongTensor,
        context_i_list: torch.LongTensor,
        target_i_list: torch.LongTensor,
    ):
        eu = self.user_embedding.forward(u)
        ei = self.item_embedding.forward(context_i_list)
        # pi = self.positional_encoding.forward(ei)
        self.transformer_encoder.forward(torch.cat([eu.unsqueeze(-2), pi], dim=-2))

In [44]:
args = Args(batch_size=10)
rnd = np.random.RandomState(args.seed)

context_size = 10

model = Model(d_model=args.d, user_n=user_n, item_n=item_n)
sampler = UpliftBasedPointwiseSampler(user_n=user_n, item_n=item_n, df=df)

u_list, target_i_list, r_ui_list = sampler.sample(
    rnd=rnd,
    args=args,
)

cls_token = item_n
context_i_list = []

for u in u_list:
    context_i = [cls_token] + rnd.choice(
        df.iloc[u][f"{args.mode}_train_purchased_items"], size=context_size
    ).tolist()
    print(context_i)
    context_i_list.append(context_i)

print(context_i_list)
model.forward(
    u=torch.LongTensor(u_list),
    context_i_list=torch.LongTensor(context_i_list),
    target_i_list=torch.LongTensor(target_i_list),
)

[3664, 723, 200, 472, 646, 3309, 993, 1345, 1693, 233, 1030]
[3664, 3439, 446, 2726, 884, 2156, 45, 2289, 3443, 307, 3190]
[3664, 2174, 1699, 3144, 3453, 212, 1277, 1587, 462, 3188, 2361]
[3664, 1704, 1211, 1643, 2277, 1691, 2422, 987, 3173, 753, 251]
[3664, 1368, 1969, 2790, 1985, 3425, 1091, 1617, 1966, 1273, 2150]
[3664, 3061, 73, 2330, 2850, 2986, 3317, 2780, 378, 2986, 3584]
[3664, 2876, 2660, 3202, 1292, 1931, 3364, 1140, 3383, 1487, 894]
[3664, 1, 3482, 3363, 170, 1688, 3425, 2153, 630, 420, 3379]
[3664, 3364, 686, 370, 1264, 1486, 2697, 998, 550, 625, 3237]
[3664, 3562, 2681, 1064, 254, 2690, 1308, 553, 1539, 1177, 1779]
[[3664, 723, 200, 472, 646, 3309, 993, 1345, 1693, 233, 1030], [3664, 3439, 446, 2726, 884, 2156, 45, 2289, 3443, 307, 3190], [3664, 2174, 1699, 3144, 3453, 212, 1277, 1587, 462, 3188, 2361], [3664, 1704, 1211, 1643, 2277, 1691, 2422, 987, 3173, 753, 251], [3664, 1368, 1969, 2790, 1985, 3425, 1091, 1617, 1966, 1273, 2150], [3664, 3061, 73, 2330, 2850, 2986, 331