In [2]:
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [3]:
# torch.manual_seed(1)

In [4]:
# Make the parent directory importable (presumably so the local `carca`
# package imported in a later cell resolves — confirm against the repo layout).
# Guarded so that re-running this cell does not pile up duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [5]:
# Tensor display settings for the whole notebook: wide lines, up to 20 edge
# items per dimension, four decimals, and no scientific notation.
torch.set_printoptions(
    precision=4,
    edgeitems=20,
    linewidth=500,
    sci_mode=False,
)

In [6]:
from carca.data import CARCADataset, load_attrs, load_ctx, load_profiles
from carca.model import MultiHeadAttention
from carca.utils import get_mask

In [7]:
# Single place to choose the dataset. The literal used to be repeated in each
# loader call, which made it easy to load mismatched files by accident.
dataset_name = "video_games"

attrs = load_attrs(dataset_name)
ctx = load_ctx(dataset_name)
user_ids, item_ids, profiles = load_profiles(dataset_name)

In [8]:
# Size of the item embedding table: one row per item id plus one extra row
# (presumably for the padding id 0 used when the embedding is built).
n_items = 1 + len(item_ids)

# Length of a single context vector, read off an arbitrary entry of `ctx`.
first_ctx = next(iter(ctx.values()))
n_ctx = first_ctx.shape[0]

# Size of the second axis of `attrs`.
n_attrs = attrs.shape[1]

In [9]:
# Hyper-parameters for Games dataset
# NOTE: d_dim, g_dim, n_heads and dropout_rate are reassigned with small
# debugging values in a later cell of this notebook.

# Optimisation
learning_rate = 1e-4
beta1, beta2 = 0.9, 0.98
l2_reg = 0.0

# Model
seq_len = 50
n_blocks = 3
n_heads = 3
d_dim = 90
g_dim = 450
dropout_rate = 0.5
residual_sa = True
residual_ca = True

# Training schedule
epochs = 800
batch_size = 128

In [10]:
# Arguments shared by the training and validation datasets; only `mode`
# differs between the two, so they are built from one dict instead of two
# copy-pasted argument lists that could drift apart.
dataset_kwargs = dict(
    user_ids=user_ids,
    item_ids=item_ids,
    profiles=profiles,
    attrs=attrs,
    ctx=ctx,
    profile_seq_len=10,
    target_seq_len=100,
)
train_data = CARCADataset(mode="train", **dataset_kwargs)
val_data = CARCADataset(mode="val", **dataset_kwargs)

# Both loaders use the same settings: two samples per batch, no shuffling,
# loading in the main process. Note this is independent of the `batch_size`
# hyper-parameter defined earlier.
loader_kwargs = dict(batch_size=2, shuffle=False, num_workers=0)
train_loader = DataLoader(train_data, **loader_kwargs)
val_loader = DataLoader(val_data, **loader_kwargs)

In [11]:
# Small values for step-by-step inspection. These replace the d_dim, g_dim,
# n_heads and dropout_rate set in the hyper-parameter cell.
d_dim, g_dim = 6, 14
n_heads = 2
dropout_rate = 0.0

In [12]:
# Item-id lookup table; index 0 is the padding index.
items_embed = nn.Embedding(n_items, d_dim, padding_idx=0)
# Linear projection of (n_ctx + n_attrs)-wide feature vectors to g_dim.
feats_embed = nn.Linear(n_ctx + n_attrs, g_dim)
# Linear projection of a (g_dim + d_dim)-wide vector down to d_dim.
joint_embed = nn.Linear(g_dim + d_dim, d_dim)

In [13]:
# --- Attention sub-layer ---
norm1 = nn.LayerNorm(d_dim)
attention_sa = nn.MultiheadAttention(
    d_dim,
    n_heads,
    dropout=dropout_rate,
    batch_first=True,
)

# --- FFN sub-layer ---
# Two kernel_size=1 convolutions with a LeakyReLU in between, each followed
# by its own dropout.
norm2 = nn.LayerNorm(d_dim)
ffn_1 = nn.Conv1d(d_dim, d_dim, kernel_size=1)
activation = nn.LeakyReLU()
dropout2 = nn.Dropout(dropout_rate)

ffn_2 = nn.Conv1d(d_dim, d_dim, kernel_size=1)
dropout3 = nn.Dropout(dropout_rate)

In [14]:
# Project-local attention module, built with the current d_dim / n_heads and
# with dropout fixed at zero (it does not read `dropout_rate`).
attn_own = MultiHeadAttention(
    embed_dim=d_dim,
    num_heads=n_heads,
    dropout=0.0,
)

In [15]:
# --- Attention (used later with target embeddings as queries) ---
attention_ca = nn.MultiheadAttention(d_dim, n_heads, batch_first=True)

# --- FFN ---
# Maps d_dim channels to a single channel, followed by a sigmoid.
# NOTE: `ffn_ca` is rebound to an nn.Linear in a later cell of this notebook.
ffn_ca = nn.Conv1d(d_dim, 1, kernel_size=1)
sig = nn.Sigmoid()

In [16]:
# Dropout / layer-norm pair; the only visible use of dropout_m is a
# commented-out call on the profile embedding further down.
dropout_m = nn.Dropout(dropout_rate)
norm_m = nn.LayerNorm(d_dim)

In [17]:
# Iterator over the training loader so batches can be pulled one at a time
# with next().
loader_iter = iter(train_loader)

In [18]:
# Pull one batch. As used in later cells: p_x / o_x go through the item
# embedding and get_mask, p_ac / o_ac go through the feature projection, and
# y_true is the target in the loss. (p_* looks like the user profile sequence
# and o_* the target sequence — TODO confirm against CARCADataset.)
p_x, p_ac, o_x, o_ac, y_true = next(loader_iter)

In [19]:
# Per-position masks for both id sequences; later cells multiply activations
# by them and treat `mask == 0` as positions to ignore.
p_mask = get_mask(p_x)
o_mask = get_mask(o_x)

In [20]:
# Embed the profile sequence. Modules are invoked by calling them rather than
# through `.forward(...)`, so any registered hooks run as PyTorch intends.
p_z = items_embed(p_x)
p_q = feats_embed(p_ac)
# Concatenate both embeddings on the feature axis and project to d_dim.
p_e = joint_embed(torch.cat((p_z, p_q), dim=-1))
# Zero out masked positions (the mask is broadcast over the feature axis).
p_e = p_e * p_mask.unsqueeze(2)

In [20]:
# p_e = dropout_m.forward(p_e)

In [21]:
# mat1, mat2 = p_mask.unsqueeze(1).permute(0, 2, 1).float(), p_mask.unsqueeze(1).float()
# sa_mask = torch.bmm(mat1, mat2)
# sa_mask = sa_mask == 0.0
# sa_mask = torch.where(sa_mask == 0.0, -1e6, 0.0)
# sa_mask = torch.repeat_interleave(sa_mask, n_heads, 0)

In [22]:
# True where the profile mask is zero. It is only referenced by the
# commented-out nn.MultiheadAttention call, as its key_padding_mask.
sa_mask = p_mask.eq(0.0)

In [23]:
# Per-sample outer product of the profile mask with itself: entry (i, j) is
# True only when both position i and position j are unmasked.
mask_col = p_mask.unsqueeze(2).float()
mask_row = p_mask.unsqueeze(1).float()
attn_mask = torch.bmm(mask_col, mask_row).bool()
# attn_mask = torch.where(attn_mask == 0.0, -1e6, 0.0)
# Stack n_heads copies of the whole batch along dim 0.
attn_mask = torch.tile(attn_mask, (n_heads, 1, 1))

In [24]:
#mat1, mat2 = o_mask.unsqueeze(1).permute(0, 2, 1).float(), p_mask.unsqueeze(1).float()
#attn_mask = torch.bmm(mat1, mat2).bool()
#attn_mask = torch.tile(attn_mask, (n_heads, 1, 1))

In [25]:
# Layer-normalised profile embedding, used as the attention query below.
# Called as a module (not `.forward`) so hooks are honoured.
p_query = norm1(p_e)

In [26]:
# Run the project-local attention with the normalised embedding as the first
# argument and the un-normalised p_e as the second and third (presumably
# query, key, value — confirm against carca.model.MultiHeadAttention).
p_s = attn_own.forward(p_query, p_e, p_e, attn_mask=attn_mask)

In [27]:
# p_s, p_w = attention_sa.forward(p_query, p_e, p_e, key_padding_mask=sa_mask, need_weights=True, average_attn_weights=False)

In [28]:
# p_ss = attn_own.forward(p_query, p_e, p_e, attn_mask=attn_mask)

In [29]:
# Zero the attention output at masked profile positions.
p_s = p_s * p_mask[:, :, None]

In [30]:
# Element-wise product of the attention output with the query (a
# multiplicative rather than additive skip connection).
p_s = p_s * p_query

In [31]:
# Normalise before the feed-forward part; the module is called directly
# instead of via `.forward` so hooks are not bypassed.
p_s = norm2(p_s)
# Swap the last two axes so the feature axis is where nn.Conv1d expects its
# channels.
p_f = p_s.transpose(1, 2)

In [32]:
# First kernel-size-1 convolution of the feed-forward part (module call
# instead of `.forward`).
p_f = ffn_1(p_f)

In [33]:
# LeakyReLU non-linearity (module call instead of `.forward`).
p_f = activation(p_f)

In [34]:
# Dropout after the activation (module call instead of `.forward`).
p_f = dropout2(p_f)

In [35]:
# Second convolution plus dropout, then undo the earlier transpose so the
# feature axis is last again. Modules are called directly, not via `.forward`.
p_f = ffn_2(p_f)
p_f = dropout3(p_f)
p_f = p_f.transpose(1, 2)

In [36]:
# Zero the feed-forward output at masked profile positions.
p_f = p_f * p_mask[:, :, None]

In [22]:
# Embed the o_* sequence with the same three layers used for the profile.
# Modules are called directly (not `.forward`) so hooks are honoured.
e_z = items_embed(o_x)
e_q = feats_embed(o_ac)
e_e = joint_embed(torch.cat((e_z, e_q), dim=-1))
# Zero out positions masked by o_mask.
e_e = e_e * o_mask.unsqueeze(2)

In [44]:
# Cross-attention: e_e is the query; p_f is both key and value. Key positions
# where p_mask is zero are excluded through key_padding_mask.
# The attention weights were requested but thrown away, so they are no longer
# computed (need_weights=False still returns a 2-tuple with None as the
# second element). The module is called directly rather than via `.forward`.
s, _ = attention_ca(
    e_e,
    p_f,
    p_f,
    key_padding_mask=(p_mask == 0),
    need_weights=False,
)

In [45]:
# Zero the cross-attention output at positions masked by o_mask.
s = s * o_mask[:, :, None]

In [46]:
# Element-wise product of the cross-attention output with its query e_e.
s = s * e_e

In [41]:
# s = s.permute(0, 2, 1)

In [42]:
# Rebind ffn_ca (previously an nn.Conv1d) to a linear layer mapping d_dim
# features to one value; it acts on the last axis, so `s` needs no permute.
ffn_ca = nn.Linear(d_dim, 1)

In [47]:
# Inspect the shape of `s` before it goes through ffn_ca.
s.shape

torch.Size([2, 20, 6])

In [48]:
# One value per position; the module is called directly instead of via
# `.forward`.
y = ffn_ca(s)

In [49]:
# Inspect the shape of ffn_ca's output.
y.shape

torch.Size([2, 20, 1])

In [43]:
# Squash to (0, 1); the module is called directly instead of via `.forward`.
y = sig(y)

In [44]:
# Drop only the trailing singleton axis left by ffn_ca. A bare squeeze() would
# also remove the batch or sequence axis whenever either has length 1.
y = y.squeeze(-1)

In [45]:
# Element-wise binary cross-entropy written out by hand; eps keeps both
# logarithms finite when y reaches 0 or 1.
eps = 1e-8
pos_term = y_true * torch.log(y + eps)
neg_term = (1.0 - y_true) * torch.log(1.0 - y + eps)
loss = -(pos_term + neg_term)

In [46]:
# Average the per-position loss over positions where o_mask is non-zero.
masked_total = torch.sum(loss * o_mask)
mask_total = torch.sum(o_mask)
loss = masked_total / mask_total

In [47]:
# Keep gradients on these intermediate tensors so they can be inspected after
# backward().
for _tensor in (p_e, p_s, p_f, e_e, s, y):
    _tensor.retain_grad()

In [48]:
# Back-propagate from the scalar masked loss.
loss.backward()

In [21]:
# Softmax over the last axis.
softmax = nn.Softmax(-1)

In [59]:
# Split the feature axis into per-head chunks and stack the chunks along the
# batch axis.
# NOTE(review): key and value are both built from p_e here, whereas the
# nn.MultiheadAttention call above is given p_f — confirm this is intended.
head_dim = d_dim // n_heads
query = torch.cat(e_e.split(head_dim, dim=2), dim=0)
key = torch.cat(p_e.split(head_dim, dim=2), dim=0)
value = torch.cat(p_e.split(head_dim, dim=2), dim=0)

In [60]:
# Per-sample outer product of o_mask (rows) with p_mask (columns): entry
# (i, j) is True only when row position i and column position j are both
# unmasked.
o_col = o_mask.unsqueeze(2).float()
p_row = p_mask.unsqueeze(1).float()
attn_mask = torch.bmm(o_col, p_row).bool()
# Stack n_heads copies of the whole batch along dim 0.
attn_mask = torch.tile(attn_mask, (n_heads, 1, 1))

In [61]:
# Additive mask: 0 where attention is allowed, -1e6 where it is not.
add_mask = torch.where(~attn_mask, -1e6, 0.0)

In [63]:
# o_mask repeated once per head along dim 0, with a trailing axis added so it
# broadcasts over the last axis of the attention weights.
o_mask_per_head = torch.tile(o_mask, (n_heads, 1))
query_mask = o_mask_per_head.unsqueeze(-1)

In [56]:
# Unmasked query-key dot products. The next cell recomputes `out` with the
# additive mask via baddbmm, so this value is overwritten when the cells run
# in file order.
out = torch.bmm(query, key.transpose(1, 2))

In [64]:
# Query-key dot products with the additive mask folded into the same call.
key_t = key.transpose(1, 2)
out = torch.baddbmm(add_mask, query, key_t)

In [65]:
# Divide the scores in place by the square root of d_dim / n_heads.
scale = (d_dim / n_heads) ** 0.5
out /= scale

In [66]:
# Normalise scores over the last axis; the module is called directly instead
# of via `.forward`.
out = softmax(out)

In [74]:
# Zero the whole weight row of every query position that is masked.
out = torch.mul(out, query_mask)

In [76]:
# Weighted sum of the value vectors.
out = out.bmm(value)

In [78]:
# Undo the head split: cut dim 0 back into per-head chunks and join them on
# the feature axis.
rows_per_head = out.shape[0] // n_heads
head_chunks = torch.split(out, rows_per_head, dim=0)
out = torch.cat(head_chunks, dim=2)

In [83]:
# Inspect the second batch element of the hand-rolled attention output.
out[1]

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.7562, -0.9502,  1.8173,  1.1411,  0.9791,  0.5281],
        [ 2.7138, -1.2549,  2.7758,  1.4265,  0.9230,  0.4740],
        [ 2.9417, -1.3218,  2.9812,  1.6699,  1.3413,  0.5687],
        [ 2.1356, -1.0711,  2.1962,  1.2098,  1.0064,  0.5241],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0

In [242]:
# Per-sample outer product of o_mask with p_mask as a boolean mask (this time
# without the per-head tiling).
o_col = o_mask.unsqueeze(2).float()
p_row = p_mask.unsqueeze(1).float()
attn_mask = torch.bmm(o_col, p_row).bool()

In [291]:
# Two 10 -> 1 layers, used in the cells below to compare output shapes.
linear = nn.Linear(10, 1)
conv1d = nn.Conv1d(10, 1, kernel_size=1)

In [287]:
# Random (4, 10, 10) probe tensor for the shape comparison.
in_t = torch.randn(4, 10, 10)

In [288]:
# Apply the linear layer (module call instead of `.forward`).
out_t = linear(in_t)

In [289]:
# Shape produced by nn.Linear.
out_t.shape

torch.Size([4, 10, 1])

In [293]:
# Shape produced by nn.Conv1d on the same input (module call instead of
# `.forward`).
conv1d(in_t).shape

torch.Size([4, 1, 10])

In [297]:
# Show the convolution's bias parameter.
conv1d.bias

Parameter containing:
tensor([0.11], requires_grad=True)

In [1]:
import torch

In [None]:
def normalize(inputs, epsilon=1e-8, scope="ln", reuse=None):
    """Applies layer normalization over the last dimension.

    This is a PyTorch port of the TensorFlow 1.x reference implementation.
    The TF version relied on `tf`, which this notebook never imports, so it
    could not be executed here.

    Args:
      inputs: A torch tensor with 2 or more dimensions, where the first
        dimension has `batch_size`.
      epsilon: A floating number. A very small number added to the variance
        to prevent division by zero.
      scope: Unused; kept only so existing call sites with the TF signature
        keep working.
      reuse: Unused; kept only so existing call sites with the TF signature
        keep working.

    Returns:
      A tensor with the same shape and data dtype as `inputs`.

    Note:
      gamma and beta are created as ones / zeros on every call, which matches
      the initial state of the TF variables. They are not trainable here; use
      `torch.nn.LayerNorm` when a learnable affine transform is needed.
    """
    params_shape = inputs.shape[-1:]

    # Population (biased) variance, matching tf.nn.moments in the reference.
    variance, mean = torch.var_mean(inputs, dim=-1, keepdim=True, unbiased=False)

    # Identity affine parameters on the same dtype / device as the input.
    beta = inputs.new_zeros(params_shape)
    gamma = inputs.new_ones(params_shape)

    normalized = (inputs - mean) / ((variance + epsilon) ** (0.5))
    outputs = gamma * normalized + beta

    return outputs

In [3]:
# Random (2, 4, 4) tensor used to compare the manual layer norm against
# torch.nn.LayerNorm.
inputs = torch.randn(2, 4, 4)

In [19]:
# Step 1 of the manual layer norm: same epsilon as the reference function,
# parameter shape taken from the last axis of the input.
epsilon = 1e-8
inputs_shape = inputs.shape
params_shape = inputs_shape[-1:]

# Biased variance and mean over the last axis, keeping that axis for
# broadcasting.
var, mean = torch.var_mean(inputs, dim=-1, keepdim=True, unbiased=False)

In [20]:
# Step 2: identity affine parameters (shift of zeros, scale of ones).
beta = torch.zeros(params_shape)
gamma = torch.ones(params_shape)

In [21]:
# Step 3: centre, then divide by the epsilon-stabilised standard deviation.
centered = inputs - mean
std = (var + epsilon) ** (0.5)
normalized = centered / std

In [22]:
# Display the manually normalised tensor.
normalized

tensor([[[ 0.1216,  0.9722,  0.5581, -1.6519],
         [-0.1728, -1.0248,  1.6483, -0.4507],
         [ 0.4595,  0.2377, -1.6695,  0.9723],
         [-1.6212,  0.0283,  0.5703,  1.0226]],

        [[ 0.1765, -1.1010,  1.5415, -0.6170],
         [-0.1468, -1.4077,  1.4052,  0.1493],
         [ 1.5495, -0.9658,  0.2061, -0.7898],
         [ 0.6270, -0.2344, -1.5144,  1.1218]]])

In [23]:
# Step 4: apply the affine transform (scale by gamma, shift by beta).
outputs = normalized * gamma + beta

In [24]:
# Display the result after the affine step (gamma is ones and beta is zeros,
# so it matches `normalized`).
outputs

tensor([[[ 0.1216,  0.9722,  0.5581, -1.6519],
         [-0.1728, -1.0248,  1.6483, -0.4507],
         [ 0.4595,  0.2377, -1.6695,  0.9723],
         [-1.6212,  0.0283,  0.5703,  1.0226]],

        [[ 0.1765, -1.1010,  1.5415, -0.6170],
         [-0.1468, -1.4077,  1.4052,  0.1493],
         [ 1.5495, -0.9658,  0.2061, -0.7898],
         [ 0.6270, -0.2344, -1.5144,  1.1218]]])

In [17]:
# Built-in layer norm over a last axis of size 4, for comparison with the
# manual version. `eps` is left at the library default, whereas the manual
# cells use 1e-8.
ln = torch.nn.LayerNorm(normalized_shape=4, elementwise_affine=True)

In [18]:
# Built-in result for side-by-side comparison with `outputs` (module call
# instead of `.forward`).
ln(inputs)

tensor([[[ 0.1216,  0.9722,  0.5581, -1.6519],
         [-0.1728, -1.0248,  1.6483, -0.4507],
         [ 0.4595,  0.2376, -1.6692,  0.9721],
         [-1.6212,  0.0283,  0.5703,  1.0226]],

        [[ 0.1765, -1.1009,  1.5414, -0.6170],
         [-0.1468, -1.4077,  1.4052,  0.1493],
         [ 1.5495, -0.9658,  0.2061, -0.7898],
         [ 0.6270, -0.2344, -1.5144,  1.1218]]])