In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [None]:
import tensorflow as tf

In [2]:
import numpy as np

In [2]:
import sys
# Make the parent directory importable (the `carca` package is imported next);
# guarded so re-running this cell does not keep appending duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [3]:
from carca.model import Embeddings
from carca.data import CARCADataset, load_attrs, load_ctx, load_profiles
from carca.utils import get_mask

In [5]:
# Let numpy print wide arrays on a single line instead of wrapping at the default width.
np.set_printoptions(linewidth=500)

In [6]:
# Same wide line width for torch tensors so they print unwrapped, like the numpy arrays.
torch.set_printoptions(linewidth=500)

In [16]:
def forward_torch(query, key, value, q_mask, k_mask, d, H):
    """Multi-head scaled dot-product attention with explicit padding masks.

    Written to match ``forward_tf`` (the TensorFlow reference in this notebook)
    numerically:
    - heads are split along the feature axis and stacked along the batch axis;
    - logits of masked (query, key) pairs are replaced by -(2**32) + 1 after scaling;
    - after the softmax, each query row is multiplied by its ``q_mask`` value,
      which zeroes padded queries.

    Args:
        query: (N, T_q, d) tensor.
        key: (N, T_k, d) tensor.
        value: (N, T_k, d) tensor.
        q_mask: (N, T_q) mask (bool, int or float); non-zero marks a real query.
        k_mask: (N, T_k) mask (bool, int or float); non-zero marks a real key.
        d: model dimension; must be divisible by ``H``.
        H: number of attention heads.

    Returns:
        (N, T_q, d) tensor in the dtype of ``query``.

    Raises:
        ValueError: if ``d`` is not divisible by ``H``.
    """
    if d % H != 0:
        raise ValueError(f"d={d} must be divisible by H={H}")
    head_dim = d // H

    def _split_heads(x):
        # (N, T, d) -> (H*N, T, d/H); head h of batch item n ends up at index h*N + n.
        return torch.cat(torch.split(x, head_dim, dim=2), dim=0)

    query, key, value = _split_heads(query), _split_heads(key), _split_heads(value)

    # Pairwise validity (N, T_q, T_k) by broadcasting instead of torch.bmm, which
    # does not accept bool tensors.
    attn_mask = (q_mask != 0).unsqueeze(2) & (k_mask != 0).unsqueeze(1)
    # Repeating the whole batch H times matches the head-major layout of _split_heads.
    attn_mask = torch.tile(attn_mask, (H, 1, 1))

    scores = torch.bmm(query, key.transpose(1, 2)) / head_dim ** 0.5
    # Replace (not add to) masked logits, in the scores' own dtype; clamp the constant
    # so it stays representable in low-precision dtypes.
    fill_value = max(-(2.0 ** 32) + 1.0, torch.finfo(scores.dtype).min)
    scores = scores.masked_fill(~attn_mask, fill_value)
    weights = F.softmax(scores, dim=-1)

    # Weight each query row by its mask value (zero for padded queries).
    weight_mask = torch.tile(q_mask, (H, 1)).unsqueeze(2).to(weights.dtype)
    weights = weights * weight_mask

    out = torch.bmm(weights, value)
    # (H*N, T_q, d/H) -> (N, T_q, d): undo the head stacking.
    return torch.cat(torch.split(out, out.shape[0] // H, dim=0), dim=2)

In [3]:
def forward_tf(query, key, value, H):
    """Reference multi-head attention using the TensorFlow API.

    This is the oracle that ``forward_torch`` is checked against (it is evaluated
    via ``tf.Session`` in the verification loop). Unlike ``forward_torch``, it takes
    no explicit masks. A query/key position counts as padding when its feature
    vector is all zeros, i.e. ``sign(sum(|x|)) == 0``.

    Args:
        query: (N, T_q, C) tensor. C must split evenly into H heads (tf.split with
            an integer count requires it), and its size must be statically known
            because the scale factor reads the static shape.
        key: (N, T_k, C) tensor.
        value: (N, T_k, C) tensor.
        H: number of attention heads.

    Returns:
        (N, T_q, C) tensor; rows for all-zero query positions are zero.
    """
    # Split and concat: heads are stacked along the batch axis
    Q_ = tf.concat(tf.split(query, H, axis=2), axis=0)  # (h*N, T_q, C/h)
    K_ = tf.concat(tf.split(key, H, axis=2), axis=0)  # (h*N, T_k, C/h)
    V_ = tf.concat(tf.split(value, H, axis=2), axis=0)  # (h*N, T_k, C/h)

    # Multiplication: raw attention logits Q K^T
    outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

    # Scale by sqrt(per-head dimension), taken from the static shape of K_
    outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

    # Key Masking: 0 where the key's feature vector is all zeros, 1 otherwise
    key_masks = tf.sign(tf.reduce_sum(tf.abs(key), axis=-1))  # (N, T_k)
    key_masks = tf.tile(key_masks, [H, 1])  # (h*N, T_k)
    key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(query)[1], 1])  # (h*N, T_q, T_k)

    # Masked logits are replaced (after scaling) by -(2**32) + 1 so softmax gives them ~0 weight
    paddings = tf.ones_like(outputs) * (-(2**32) + 1)
    outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

    # Activation: softmax over the last (key) axis
    outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

    # Query Masking: zero the attention rows of all-zero query positions
    query_masks = tf.sign(tf.reduce_sum(tf.abs(query), axis=-1))  # (N, T_q)
    query_masks = tf.tile(query_masks, [H, 1])  # (h*N, T_q)
    query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(key)[1]])  # (h*N, T_q, T_k)
    outputs *= query_masks  # element-wise; shapes match: (h*N, T_q, T_k)

    # Weighted sum of values
    outputs = tf.matmul(outputs, V_)  # (h*N, T_q, C/h)

    # Restore shape: move heads from the batch axis back onto the feature axis
    outputs = tf.concat(tf.split(outputs, H, axis=0), axis=2)  # (N, T_q, C)

    return outputs

In [4]:
# All three loaders read the same dataset; name it once so switching datasets
# is a one-line change instead of three matching edits.
DATASET = "video_games"
attrs = load_attrs(DATASET)
ctx = load_ctx(DATASET)
user_ids, item_ids, profiles = load_profiles(DATASET)

In [5]:
# Vocabulary / feature sizes consumed by the Embeddings constructor.
n_items = 1 + len(item_ids)  # one extra index, presumably reserved for padding — TODO confirm
n_ctx = next(iter(ctx.values())).shape[0]  # taken from the first ctx entry; assumes all entries share it
n_attrs = attrs.shape[1]

In [6]:
# Training split of the dataset. The sequence lengths are passed straight to
# CARCADataset; profile_seq_len=50 is consistent with the T=50 axis of key.shape below.
train_data = CARCADataset(
    mode="train",
    user_ids=user_ids,
    item_ids=item_ids,
    profiles=profiles,
    attrs=attrs,
    ctx=ctx,
    profile_seq_len=50,
    target_seq_len=100,
)

In [7]:
# shuffle=False keeps batch order deterministic, so batch i is always io_{i:03d}.npz.
train_loader = DataLoader(dataset=train_data, batch_size=128, shuffle=False)

In [8]:
# d: model dimension, split into H heads of d // H = 30 features each.
# g: third argument to Embeddings — its meaning is not visible here (TODO confirm).
d, g, H = 90, 450, 3

In [10]:
# Seed before construction so any random initialisation inside Embeddings
# (presumably its embedding tables — TODO confirm) is reproducible, which makes
# the exported mha_io fixtures regenerable.
torch.manual_seed(0)
emb = Embeddings(n_items, d, g, n_ctx, n_attrs)

In [11]:
# Take the first batch and reproduce forward_torch's head split by hand.
p_x, p_q, o_x, o_q, y_true = next(iter(train_loader))
p_mask = get_mask(p_x)
p_e = emb.forward(p_x, p_q, p_mask)
# (N, T, d) -> (H*N, T, d // H); computed three times so the three names are distinct tensors.
query, key, value = (torch.cat(torch.split(p_e, d // H, dim=2), dim=0) for _ in range(3))

In [12]:
# Expected (H * batch_size, profile_seq_len, d // H) for the first batch.
key.size()

torch.Size([384, 50, 30])

In [17]:
import os

# Export reference fixtures: for every batch, save the embedded profile (input) and
# forward_torch's output so forward_tf can be checked against them later.
# np.savez does not create directories, so make sure the target exists.
os.makedirs("mha_io", exist_ok=True)
with torch.no_grad():  # inference only; no autograd graph is needed for the fixtures
    for i, (p_x, p_q, o_x, o_q, y_true) in enumerate(train_loader):
        p_mask = get_mask(p_x)
        p_e = emb.forward(p_x, p_q, p_mask)
        out = forward_torch(p_e, p_e, p_e, p_mask, p_mask, d, H)

        np.savez(f"mha_io/io_{i:03d}.npz", arr_in=p_e.detach().numpy(), arr_out=out.detach().numpy())

In [4]:
import os

In [5]:
# Fixture files written by the export loop, in a stable (sorted) order. Only .npz
# files are kept, so stray files in the directory (e.g. .DS_Store) don't break np.load.
io_files = sorted(
    f
    for f in os.listdir("mha_io")
    if f.endswith(".npz") and os.path.isfile(os.path.join("mha_io", f))
)

In [11]:
eps = 1e-5

# Fail loudly instead of passing vacuously when no fixtures were found.
assert io_files, "no fixture files found in mha_io/ - run the export loop first"

for f in io_files:
    # NpzFile keeps the file open; the context manager closes it.
    with np.load(os.path.join("mha_io", f)) as data:
        arr_in = data["arr_in"]
        arr_out = data["arr_out"]

    # A fresh graph and session per file: nothing accumulates in the default graph,
    # and the session is closed when the block exits.
    with tf.Graph().as_default(), tf.Session() as sess:
        t_in = tf.convert_to_tensor(arr_in)
        t_out = sess.run(forward_tf(t_in, t_in, t_in, H))

    # Unlike a bare `abs(a - b) < eps`, assert_allclose also rejects shape mismatches
    # (which would otherwise broadcast silently) and reports the worst difference.
    np.testing.assert_allclose(t_out, arr_out, rtol=0, atol=eps, err_msg=f)