In [2]:
import sys
import argparse
import logging
import os
import torch
from torch import optim, nn
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import random
from datetime import datetime

from dataset import load_data, TrainDataset, EvalDataset
from dataloader import DataLoaderHandler
from model import SIA
from trainer import train, evaluate

import torch.nn.functional as F
import itertools
from tqdm import tqdm

from einops import repeat

BADEDIR = "./"

In [3]:
parser = argparse.ArgumentParser()
args = parser.parse_args([])
args.device = "cuda" if torch.cuda.is_available() else "cpu"
args.dataset = "amazon_beauty"
args.maxlen = 50
args.batch_size = 64
args.content = ["image", "desc"]
args.latent_dim = 128
args.item_num_outputs = 8
args.item_num_heads = 8
args.item_num_latents = 8
args.item_dim_hidden = 64
args.attn_depth = 5
args.attn_self_per_cross = 2
args.attn_dropout = 0.0
args.attn_ff_dropout = 0.2
args.attn_num_heads = 8
args.attn_dim_head = 64
args.eval_sample_mode = "uni"
args.lr = 1e-3
args.weight_decay = 0.
args.num_epochs = 100
args.lr_milestones = None
args.early_stop = 5

logger = logging.getLogger()

In [4]:
raw_dir = "./dataset/raw/"
processed_dir = "./dataset/processed/"

inter, item_feats, pop = load_data(args, raw_dir, processed_dir, logger)
dim_item_feats = [tuple(feat.values())[0].shape[0] for feat in item_feats]
num_items = len(pop)

In [5]:
train_dataset = TrainDataset(inter, item_feats, args, logger)
val_dataset = EvalDataset(inter, item_feats, pop, args, logger, mode="val", eval_mode=args.eval_sample_mode)
test_dataset = EvalDataset(inter, item_feats, pop, args, logger, mode="test", eval_mode=args.eval_sample_mode)
train_loader = DataLoaderHandler("train", train_dataset, args, logger).get_dataloader()
val_loader = DataLoaderHandler("val", val_dataset, args, logger).get_dataloader()
test_loader = DataLoaderHandler("test", test_dataset, args, logger).get_dataloader()

In [6]:
model = SIA(
    latent_dim=args.latent_dim,
    item_num_outputs=args.item_num_outputs,
    item_num_heads=args.item_num_heads,
    item_num_latents=args.item_num_latents,
    item_dim_hidden=args.item_dim_hidden,
    attn_depth=args.attn_depth,
    attn_self_per_cross=args.attn_self_per_cross,
    attn_dropout=args.attn_dropout,
    attn_ff_dropout=args.attn_ff_dropout,
    attn_num_heads=args.attn_num_heads,
    attn_dim_head=args.attn_dim_head,
    dim_item_feats=dim_item_feats,
    num_items=num_items,
    maxlen=args.maxlen,
    device=args.device
)

In [15]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"********** {name} **********")
        print(param.data)

********** set_transformers.0.enc.0.mab.fc_q.weight **********
tensor([[-3.1800e-02,  2.4777e-02,  3.6600e-02,  ..., -1.4485e-03,
         -1.4242e-02, -6.2072e-05],
        [ 3.0920e-02,  1.8889e-03, -2.3799e-02,  ..., -3.1601e-02,
          1.9452e-02,  1.9472e-05],
        [ 3.6784e-03,  3.7307e-02,  2.2342e-02,  ..., -2.0370e-02,
          2.3144e-02,  1.2919e-03],
        ...,
        [ 3.2521e-02, -1.6560e-02,  1.3358e-02,  ...,  5.5847e-03,
          2.3069e-02, -2.8192e-02],
        [ 2.1246e-02, -1.7563e-02,  4.3314e-02,  ..., -3.9301e-02,
          2.0351e-02,  2.2465e-02],
        [ 1.0496e-02, -7.0758e-03,  3.4191e-02,  ...,  1.7443e-02,
         -2.3627e-02,  6.7963e-03]])
********** set_transformers.0.enc.0.mab.fc_q.bias **********
tensor([-0.0224,  0.0040, -0.0140,  0.0186, -0.0415, -0.0381, -0.0406, -0.0239,
        -0.0297, -0.0318,  0.0373,  0.0380,  0.0352,  0.0184,  0.0399,  0.0279,
         0.0155,  0.0324, -0.0245,  0.0383, -0.0370, -0.0211, -0.0289, -0.0144,
    

In [36]:
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
milestones = args.lr_milestones if args.lr_milestones else []
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.5)

In [86]:
class WarmupBeforeMultiStepLR(torch.optim.lr_scheduler.LambdaLR):
    def __init__(self, optimizer, warmup_steps=None, milestones=None, gamma=None, last_epoch=-1):
        self.gamma = 1
        def lr_lambda(step):
            if warmup_steps and step < warmup_steps:
                return step / warmup_steps
            if milestones and gamma and step in milestones:
                self.gamma *= gamma
            return self.gamma

        super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)        

In [100]:
optimizer.param_groups[0]["lr"]

1.0000000000000003e-05

In [6]:
model = SimpleModel(num_items=num_items, maxlen=args.maxlen, device=args.device)

In [7]:
model.load_state_dict(torch.load("./saved/SIMPLE.pt", map_location=torch.device('cpu')))

<All keys matched successfully>

In [9]:
test_metrics = evaluate(model, test_loader, args.eval_sample_mode, num_items)

100%|██████████| 548/548 [00:12<00:00, 45.03it/s]


In [11]:
test_metrics

{'NDCG@1': 0.10248367504063417,
 'NDCG@5': 0.1859364205755028,
 'NDCG@10': 0.21774015628144958,
 'HR@1': 0.10248367504063417,
 'HR@5': 0.2641934472040834,
 'HR@10': 0.36294162935926316}

In [6]:
def get_writer(dataset_name):
    log_dir = os.path.join("./", f"log_tensorboard/{dataset_name}")
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    writer = SummaryWriter(log_dir)
    return writer


model = model.to(args.device)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
milestones = args.lr_milestones if args.lr_milestones else [int(args.num_epochs*0.8), int(args.num_epochs*0.9)]
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.5)
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
writer = get_writer(args.dataset)

2023-02-08 15:35:48.163737: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-08 15:35:48.336041: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-02-08 15:35:49.366745: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-02-08 15:35:49.366814: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

In [7]:
model = train(
    args.num_epochs, 
    args.early_stop, 
    train_loader, 
    val_loader, 
    args.eval_sample_mode,
    num_items,
    model, 
    optimizer, 
    scheduler, 
    loss_fn, 
    writer, 
    logger
)

Epoch 0 - loss: 11.131933212280273:   0%|          | 1/548 [00:05<48:53,  5.36s/it]


KeyboardInterrupt: 

In [8]:
torch.tensor([1, 2, 3]).device

device(type='cpu')

In [7]:
evaluate(model, test_loader, "uni", num_items)

100%|██████████| 548/548 [05:06<00:00,  1.79it/s]


{'NDCG@1': 0.0,
 'NDCG@5': 2.456166868033892e-05,
 'NDCG@10': 4.138834934897614e-05,
 'HR@1': 0.0,
 'HR@5': 5.7030425732128093e-05,
 'HR@10': 0.00011406085146425619}

In [12]:
dim_item_feats = [tuple(feat.values())[0].shape[0] for feat in item_feats]
num_items = len(pop)

In [13]:
model = SIA(
    latent_dim=args.latent_dim,
    item_num_outputs=args.item_num_outputs,
    item_num_heads=args.item_num_heads,
    item_num_latents=args.item_num_latents,
    item_dim_hidden=args.item_dim_hidden,
    attn_depth=args.attn_depth,
    attn_self_per_cross=args.attn_self_per_cross,
    attn_dropout=args.attn_dropout,
    attn_ff_dropout=args.attn_ff_dropout,
    attn_num_heads=args.attn_num_heads,
    attn_dim_head=args.attn_dim_head,
    dim_item_feats=dim_item_feats,
    num_items=num_items,
    maxlen=args.maxlen,
)

In [14]:
# INPUT
# seq_list: (B, N)
# next_item_list: (B,)
# item_feat_list: (B, N, d)
seq_list, pos_list, next_item_list, *item_feat_lists = next(iter(train_loader))

In [15]:
x = (seq_list, pos_list, *item_feat_lists)
logits = model(x)

In [16]:
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
loss_fn(logits, next_item_list)

tensor(11.0084, grad_fn=<NllLossBackward0>)

In [17]:
# INPUT
# seq_list: (B, N)
# next_item_list: (B,)
# item_feat_list: (B, N, d)
seq_list, pos_list, next_item_list, candidate_list, *item_feat_lists = next(iter(val_loader))

In [18]:
x = (seq_list, pos_list, *item_feat_lists)
logits = model(x)

In [19]:
logits.shape

torch.Size([64, 50302])

In [70]:
scores = logits if args.eval_sample_mode == "full" else logits.gather(dim=1, index=candidate_list)

In [71]:
rank = (-scores).argsort(dim=1)
cut = rank[:, :10]

In [72]:
next_item_list.shape

torch.Size([64])

In [74]:
one_hot_label = F.one_hot(next_item_list, num_classes=(num_items + 1))

In [78]:
hits = one_hot_label.gather(dim=1, index=cut)

In [83]:
hits.sum().item()

0

In [84]:
position = torch.arange(2, 2 + 10)

In [85]:
weights = 1 / torch.log2(position.float())

In [86]:
dcg = (hits * weights).sum(1)

In [None]:
hits_sum = hits.sum()
ndcg_sum = dcg.sum()

In [88]:
import itertools

In [90]:
[(a, b) for a, b in itertools.product(["NDCG", "HR"], [1, 5, 10])]

[('NDCG', 1), ('NDCG', 5), ('NDCG', 10), ('HR', 1), ('HR', 5), ('HR', 10)]

In [None]:
# # Arguments
# item_feat_dims = [tuple(feat.values())[0].shape[0] for feat in item_feats]
# n_items = len(pop)

# # LAYERS
# set_transformers = nn.ModuleList([
#     SetTransformer(
#         dim_input=feat_dim,
#         num_outputs=args.item_num_outputs,
#         dim_output=args.latent_dim,
#         num_inds=args.item_num_latents,
#         dim_hidden=args.item_dim_hidden,
#         num_heads=args.item_num_heads,
#         ln=True,
#     ) for feat_dim in item_feat_dims
# ])
# id_embedding = nn.Embedding(
#     num_embeddings=(n_items + 1),
#     embedding_dim=args.latent_dim,
#     device=args.device,
#     padding_idx=0
# )
# pos_embedding = nn.Embedding(
#     num_embeddings=(args.maxlen + 1),
#     embedding_dim=args.latent_dim,
#     device=args.device,
#     padding_idx=0
# )
# # id_set_transformer = SetTransformer(
# #     dim_input=args.emb_size,
# #     num_outputs=16,
# #     dim_output=args.emb_size,
# #     num_inds=16,
# #     dim_hidden=args.st_dim_hidden,
# #     num_heads=args.st_num_heads,
# #     ln=True
# # )

# get_latent_attn = lambda: PreNorm(
#     args.latent_dim,
#     Attention(
#         query_dim=args.latent_dim,
#         heads=args.attn_num_heads,
#         dim_head=args.attn_dim_head,
#         dropout=args.attn_dropout
#     ),
# )

# get_cross_attn = lambda: PreNorm(
#     args.latent_dim, 
#     Attention(
#         query_dim=args.latent_dim,
#         context_dim=args.latent_dim,
#         heads=args.attn_num_heads,
#         dim_head=args.attn_dim_head,
#         dropout=args.attn_dropout
#     ), 
#     context_dim=args.latent_dim
# )

# get_latent_ff = lambda: PreNorm(
#     args.latent_dim,
#     FeedForward(
#         args.latent_dim, 
#         dropout=args.attn_ff_dropout
#     )
# )

# get_cross_ff = lambda: PreNorm(
#     args.latent_dim,
#     FeedForward(
#         args.latent_dim, 
#         dropout=args.attn_ff_dropout
#     )
# )

# layers = nn.ModuleList([])
# for _ in range(args.attn_depth):
#     self_attns = nn.ModuleList([])
#     for _ in range(args.attn_self_per_cross):
#         self_attns.append(nn.ModuleList([
#             get_latent_attn(),
#             get_latent_ff(),            
#         ]))
#     layers.append(nn.ModuleList([
#         get_cross_attn(),
#         get_cross_ff(),
#         self_attns
#     ]))
        
# to_logits = nn.Sequential(
#     Reduce('b n d -> b d', 'mean'),
#     nn.LayerNorm(args.latent_dim),
#     nn.Linear(args.latent_dim, n_items + 1)
# )

# # FORWARD

# id_emb = id_embedding(seq_list)
# pos_emb = pos_embedding(pos_list)
# x = id_emb + pos_emb

# item_feat = []
# for set_transformer, item_feat_list in zip(set_transformers, item_feat_lists):
#     out = [set_transformer(feat.unsqueeze(0)) for feat in item_feat_list]
#     out = torch.cat(out)
#     item_feat.append(out)
# item_feat = torch.cat(item_feat, dim=1)

# mask_latent = repeat(pos_list, 'b n -> b n d', d=args.latent_dim).float()
# mask_items = torch.ones(item_feat.shape)
# mask_cross_attn = einsum("b i d, b j d -> b i j", mask_latent, mask_items) > 0
# mask_self_attn = einsum("b i d, b j d -> b i j", mask_latent, mask_latent) > 0

# for cross_attn, cross_ff, self_attns in layers:
#     x = cross_attn(x, context=item_feat, mask=mask_cross_attn) + x
#     x = cross_ff(x) + x
    
#     for self_attn, self_ff in self_attns:
#         x = self_attn(x, mask=mask_self_attn) + x
#         x = self_ff(x) + x

# x = to_logits(x)

In [None]:
for seq_list, pos_list, next_item_list, *item_feat_list in tqdm(train_loader):
    id_emb = id_embedding(seq_list)
    pos_emb = pos_embedding(pos_list)
    x = id_emb + pos_emb

    item_feat = []
    for set_transformer, item_feat_list in zip(set_transformers, item_feat_lists):
        out = [set_transformer(feat.unsqueeze(0)) for feat in item_feat_list]
        out = torch.cat(out)
        item_feat.append(out)
    item_feat = torch.cat(item_feat, dim=1)

    mask_latent = repeat(pos_list, 'b n -> b n d', d=args.latent_dim).float()
    mask_items = torch.ones(item_feat.shape)
    mask_cross_attn = einsum("b i d, b j d -> b i j", mask_latent, mask_items) > 0
    mask_self_attn = einsum("b i d, b j d -> b i j", mask_latent, mask_latent) > 0

    for cross_attn, cross_ff, self_attns in layers:
        x = cross_attn(x, context=item_feat, mask=mask_cross_attn) + x
        x = cross_ff(x) + x
        
        for self_attn, self_ff in self_attns:
            x = self_attn(x, mask=mask_self_attn) + x
            x = self_ff(x) + x

    x = to_logits(x)
    loss = cross_entropy(x, next_item_list)
    loss.backward()