In [1]:
from dataset import load_data, TrainDataset, EvalDataset
from dataloader import DataLoaderHandler
from model import SetTransformer
import logging
import argparse

import torch
import torch.nn as nn
from torch import einsum
import pandas as pd
import numpy as np
import pickle

from tqdm import tqdm

In [2]:
from model.modules import PreNorm, Attention, FeedForward
from einops import rearrange, repeat
from einops.layers.torch import Reduce

In [3]:
# Start from an empty argparse namespace (no CLI flags are parsed) and fill it
# in by hand, so the notebook can pass around the same kind of `args` object
# that script entry points use.
parser = argparse.ArgumentParser()
args = parser.parse_args([])
logger = logging.getLogger()

# Runtime and data settings.
vars(args).update(
    device="cuda" if torch.cuda.is_available() else "cpu",
    dataset="amazon_beauty",
    maxlen=50,
    batch_size=64,
    content=["image", "desc"],
)

# Width used for the embeddings and attention layers built later on.
vars(args).update(latent_dim=128)

# Settings read when the per-feature SetTransformer encoders are constructed.
vars(args).update(
    item_num_outputs=8,
    item_num_heads=8,
    item_num_latents=8,
    item_dim_hidden=64,
)

# Settings for the attention stack.
vars(args).update(
    attn_num_iters=5,
    attn_self_per_cross=2,
    attn_dropout=0.0,
    attn_num_heads=8,
    attn_dim_head=64,
)

In [4]:
# Directories handed to `load_data`; both are relative to the notebook's
# working directory.
raw_dir = "./dataset/raw/"
processed_dir = "./dataset/processed/"

# `load_data` returns three objects. Later cells iterate `item_feats` as a
# collection of per-feature mappings and use `len(pop)` as the item count.
inter, item_feats, pop = load_data(args, raw_dir, processed_dir, logger)
train_dataset = TrainDataset(inter, item_feats, args, logger)

In [5]:
# Wrap the training dataset in the project's loader helper and keep only the
# iterable it hands back; the helper object itself is not needed afterwards.
train_loader = DataLoaderHandler("train", train_dataset, args, logger).get_dataloader()

In [6]:
# Arguments
# Input width of each content feature, read from the first value of each
# per-feature mapping (assumes every value in a mapping has the same width —
# TODO confirm against the dataset).
item_feat_dims = [next(iter(feat.values())).shape[0] for feat in item_feats]
n_items = len(pop)

# LAYERS
# One set encoder per content feature. They are held in an nn.ModuleList and
# moved to `args.device` so that they live on the same device as the two
# embeddings below and their parameters are registered like any other module.
# Iterating / zipping over the ModuleList works exactly as with a plain list.
set_transformers = nn.ModuleList(
    [
        SetTransformer(
            dim_input=feat_dim,
            num_outputs=args.item_num_outputs,
            dim_output=args.latent_dim,
            num_inds=args.item_num_latents,
            dim_hidden=args.item_dim_hidden,
            num_heads=args.item_num_heads,
            ln=True,
        )
        for feat_dim in item_feat_dims
    ]
).to(args.device)

# Item-id lookup table; index 0 is reserved as the padding index, hence the
# extra row on top of `n_items`.
id_embedding = nn.Embedding(
    num_embeddings=(n_items + 1),
    embedding_dim=args.latent_dim,
    device=args.device,
    padding_idx=0
)

# Position lookup table; index 0 is again the padding index, hence
# `maxlen + 1` rows.
pos_embedding = nn.Embedding(
    num_embeddings=(args.maxlen + 1),
    embedding_dim=args.latent_dim,
    device=args.device,
    padding_idx=0
)

# Disabled experiment: a set encoder over the id embeddings.
# NOTE(review): it reads args.emb_size / args.st_dim_hidden / args.st_num_heads,
# none of which are set in the configuration cell.
# id_set_transformer = SetTransformer(
#     dim_input=args.emb_size,
#     num_outputs=16,
#     dim_output=args.emb_size,
#     num_inds=16,
#     dim_hidden=args.st_dim_hidden,
#     num_heads=args.st_num_heads,
#     ln=True
# )

In [7]:
# INPUT (a single batch pulled from the training loader)
# seq_list: (B, N)
# pos_list: (B, N) — presumably position ids aligned with seq_list; it is fed
#           to pos_embedding and the result is added to the id embeddings
# next_item_list: (B,)
# item_feat_lists: every remaining batch element, one per content feature;
#                  each one is (B, N, d)
seq_list, pos_list, next_item_list, *item_feat_lists = next(iter(train_loader))

In [8]:
# Latent sequence = item-id embedding + position embedding, summed element-wise.
# Index 0 is the padding index of both tables.
id_emb = id_embedding(seq_list)
pos_emb = pos_embedding(pos_list)
latents = id_emb + pos_emb

In [9]:
# Encode each content feature with its own set encoder. Samples are pushed
# through one at a time (with a leading batch axis of 1) and stacked back along
# dim 0; the per-feature results are then joined along dim 1.
item_feat = torch.cat(
    [
        torch.cat([encoder(sample.unsqueeze(0)) for sample in feature_batch])
        for encoder, feature_batch in zip(set_transformers, item_feat_lists)
    ],
    dim=1,
)

In [10]:
# Boolean attention masks derived from the position ids. A latent position
# counts as "real" when its position id is non-zero (0 is the padding index of
# pos_embedding); every encoded item-feature slot counts as real.
mask_latent = repeat(pos_list, 'b n -> b n d', d=args.latent_dim).float()
# Allocate the all-ones context mask on the same device/dtype as mask_latent:
# a bare torch.ones(...) always lands on the CPU and would make the einsum
# below fail with a device mismatch when the batch lives on CUDA.
mask_items = torch.ones(item_feat.shape, dtype=mask_latent.dtype, device=mask_latent.device)
# [b, i, j] is True when latent position i is real (the context side is all ones).
mask_cross_attn = einsum("b i d, b j d -> b i j", mask_latent, mask_items) > 0
# [b, i, j] is True when both latent positions i and j are real.
mask_self_attn = einsum("b i d, b j d -> b i j", mask_latent, mask_latent) > 0

In [11]:
# Hyper-parameters common to all three attention layers.
attn_kwargs = dict(
    query_dim=args.latent_dim,
    heads=args.attn_num_heads,
    dim_head=args.attn_dim_head,
    dropout=args.attn_dropout,
)

# Two independent self-attention layers over the latents.
self_attn_beg = Attention(**attn_kwargs)
self_attn = Attention(**attn_kwargs)

# Cross-attention from the latents to a context of the same width, wrapped in
# PreNorm for both the query side and the context side.
cross_attn = PreNorm(
    args.latent_dim,
    Attention(context_dim=args.latent_dim, **attn_kwargs),
    context_dim=args.latent_dim,
)

In [12]:
# Self-attention over the latent sequence, restricted by the latent/latent mask.
# The recorded output shows the result keeps the latents' shape.
out = self_attn(latents, mask=mask_self_attn)
out.shape

torch.Size([64, 50, 128])

In [12]:
# Cross-attention: the latents attend to the encoded item features.
# NOTE(review): the recorded run of this cell raised
#   TypeError: <lambda>() got an unexpected keyword argument 'context'
# The callable that rejects `context` is not defined in this notebook; check
# how PreNorm / Attention in model.modules forward their keyword arguments.
out = cross_attn(latents, context=item_feat, mask=mask_cross_attn)
out.shape

TypeError: <lambda>() got an unexpected keyword argument 'context'

In [None]:
# Manual attention walkthrough: queries come from the embedded sequence, keys
# and values both come from the concatenated item features.
# NOTE(review): `embedded` is only created in a cell further down and
# `item_feat_concat` is not created anywhere in this notebook.
q = embedded
k = v = item_feat_concat
for tensor in (q, k, v):
    print(tensor.shape)

In [None]:
# Last-axis width of each attention input.
dim_q, dim_k, dim_v = (t.shape[-1] for t in (q, k, v))

# Multi-head geometry: all heads together span `inner_dim` channels, and the
# scores are scaled by 1 / sqrt(dim_head).
num_heads, dim_head = 8, 64
inner_dim = num_heads * dim_head
scale = dim_head ** -0.5

In [None]:
# Bias-free input projections to the multi-head width, built in q, k, v order.
fc_q, fc_k, fc_v = (
    nn.Linear(in_dim, inner_dim, bias=False) for in_dim in (dim_q, dim_k, dim_v)
)

# Dropout (p=0.2) is applied to the attention weights in a later cell;
# `to_out` maps the merged heads back to the query width.
dropout = nn.Dropout(0.2)
to_out = nn.Linear(inner_dim, dim_q)

In [None]:
# Project the three inputs to the multi-head width.
# NOTE: q, k and v are rebound here, so re-running this cell would feed the
# already-projected tensors through the projections a second time.
q, k, v = fc_q(q), fc_k(k), fc_v(v)
for projected in (q, k, v):
    print(projected.shape)

In [None]:
# Split the channel axis into heads and fold the heads into the batch axis,
# giving tensors laid out as (batch * heads, positions, dim_head).
q, k, v = (
    rearrange(t, "b n (h d) -> (b h) n d", h=num_heads) for t in (q, k, v)
)
for per_head in (q, k, v):
    print(per_head.shape)

In [None]:
# Scaled dot-product scores between every query position i and key position j,
# computed independently for each (batch, head) slice.
sim = einsum("b i d, b j d -> b i j", q, k) * scale
sim.shape

In [None]:
# Shape of the item-id batch that the padding mask below is built from.
seq_list.shape

In [None]:
# Query-side validity: non-zero item ids mark real positions (0 is the padding
# id). NOTE(review): args.emb_size is not set in the configuration cell.
mask_q = repeat(seq_list, 'b n -> b n d', d=args.emb_size).float()
# Every key/value slot is treated as valid. The ones tensor is created on
# mask_q's device and dtype; a bare torch.ones(...) is always allocated on the
# CPU and would break the einsum below when the batch is on CUDA.
mask_kv = torch.ones(item_feat_concat.shape, dtype=mask_q.dtype, device=mask_q.device)
# [b, i, j] is True when query position i is a real (non-padding) item.
mask = einsum("b i d, b j d -> b i j", mask_q, mask_kv) > 0

In [None]:
# Shape of the per-batch (query, key) mask before it is copied per head.
mask.shape

In [None]:
# Copy the mask once per head so that its leading axis lines up with the
# "(b h)" axis of q / k / v.
# NOTE(review): `mask` is rebound, so running this cell twice multiplies the
# leading axis by num_heads a second time.
mask = repeat(mask, "b i j -> (b h) i j", h=num_heads)

In [None]:
# Most negative finite value of sim's dtype; used as the fill value for masked
# scores before the softmax.
max_neg_value = -torch.finfo(sim.dtype).max

In [None]:
# Display the raw scores before masking.
sim

In [None]:
# Overwrite the scores at masked-out positions (where `mask` is False) with the
# large negative value. This modifies `sim` in place.
sim.masked_fill_(~mask, max_neg_value)

In [None]:
# Normalise the scores over the key axis (last dimension) into attention weights.
attn = sim.softmax(dim=-1)

In [None]:
# Shape of the attention weights.
attn.shape

In [None]:
# Apply the dropout layer created earlier to the attention weights.
# `attn` is rebound, so re-running this cell applies dropout again.
attn = dropout(attn)

In [None]:
# Attention-weighted sum of the values for each (batch, head) slice ...
out = einsum("b i j, b j d -> b i d", attn, v)
# ... then unfold the heads from the batch axis and merge them back into the
# channel axis.
out = rearrange(out, "(b h) n d -> b n (h d)", h=num_heads)

In [None]:
# Shape after merging the heads.
out.shape

In [None]:
# Project the merged heads back to the query width and inspect the shape.
to_out(out).shape

In [None]:
# Iterate over the whole training loader and encode each batch.
# Each batch is unpacked as (seq_list, pos_list, next_item_list, *features),
# the same layout used when a single batch is pulled from this loader earlier
# in the notebook. Leaving out pos_list would shift every later element by one
# slot, pairing the set encoders with the wrong tensors.
for seq_list, pos_list, next_item_list, *item_feat_lists in tqdm(train_loader):
    # Per-feature set encoding, one sample at a time, joined along dim 1.
    item_out = []
    for set_transformer, item_feat_list in zip(set_transformers, item_feat_lists):
        out = [set_transformer(sample.unsqueeze(0)) for sample in item_feat_list]
        out = torch.cat(out)
        item_out.append(out)
    item_out = torch.cat(item_out, dim=1)

    # Set encoding of the embedded item-id sequence, one sample at a time.
    # NOTE(review): `id_set_transformer` exists only as commented-out code and
    # `embedding` is created in a later cell, so this line cannot run on a
    # fresh kernel as the notebook stands.
    seq_out = [id_set_transformer(embedding(seq).unsqueeze(0)) for seq in seq_list]
    seq_out = torch.cat(seq_out)

In [None]:
# Item-id embedding table with index 0 reserved for padding.
# NOTE(review): this repeats the id embedding built earlier in the notebook,
# but sizes it with args.emb_size, which the configuration cell does not set.
n_items = len(pop)
embedding = nn.Embedding(
    num_embeddings=(n_items + 1),
    embedding_dim=args.emb_size,
    padding_idx=0,
    device=args.device
)

# Embed the current batch of item-id sequences.
embedded = embedding(seq_list)

In [None]:
# True where the item id is the padding id 0.
padded = (seq_list == 0)
# Block every (query, key) pair in which either side is padding, then give each
# head its own copy of the mask.
# The copies are made with repeat_interleave so that rows b*H .. b*H+H-1 all
# belong to batch element b. Tensor.repeat(H, 1, 1) would instead tile the
# whole batch H times (row k -> batch element k % B), which pairs masks with
# the wrong samples if the attention module folds heads batch-major.
# NOTE(review): assumes `ma` is a torch.nn.MultiheadAttention built with
# batch_first=True and args.ia_num_heads heads — it is not defined in this
# notebook and args.ia_num_heads is not set in the configuration cell.
attn_mask = (padded.unsqueeze(2) | padded.unsqueeze(1)).repeat_interleave(args.ia_num_heads, dim=0)
attn_output, _ = ma(query=embedded, key=embedded, value=embedded, attn_mask=attn_mask)

In [None]:
# No item-feature slot is masked: an all-False (batch, slots) tensor, created
# directly at its final shape and on the same device as `padded` rather than
# allocating a full (batch, slots, dim) tensor on the CPU and slicing it.
item_feat_mask = torch.zeros(item_feat_concat.shape[:2], dtype=torch.bool, device=padded.device)
# Block every pair whose query position is padding.
attn_mask = padded.unsqueeze(2) | item_feat_mask.unsqueeze(1)
# One copy per head, grouped by batch element (rows b*H .. b*H+H-1 belong to
# batch element b); Tensor.repeat(H, 1, 1) would tile the batch instead.
# NOTE(review): assumes `ma` is a torch.nn.MultiheadAttention built with
# batch_first=True — it is not defined in this notebook.
attn_mask = attn_mask.repeat_interleave(args.ia_num_heads, dim=0)
attn_output, _ = ma(query=attn_output, key=item_feat_concat, value=item_feat_concat, attn_mask=attn_mask)