In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [2]:
q_gram_1 = torch.rand(10,5,20)
q_gram_2 = torch.rand(10,5,20)
q_seq = torch.rand(10,5,20)
q_his = torch.rand(10,5,20)

query_sources = [q_gram_1, q_gram_2, q_seq]

In [9]:
class MultiGrainSemanticUnit(nn.Module):

    def __init__(self, hiddne_dim=20):
        super().__init__()
        self.dp_attn = DotProductAttention()
        self.trm = nn.TransformerEncoderLayer(d_model=hiddne_dim, nhead=4, dim_feedforward=hiddne_dim)
    
    def forward(self, query_sources, q_his):
        """
        query_sources is expected to be [q_gram_1, q_gram_2, q_seq]
        """

        pooled_sources = []
        for query_source in query_sources:
            pooled_sources.append(torch.mean(query_source, dim=1).unsqueeze(1))

        
        q_seq_his = self.dp_attn(pooled_sources[-1], q_his, q_his.shape[0])
        q_seq_seq = torch.mean(self.trm(q_seq),1, keepdim=True)
        q_mix = q_seq_his + q_seq_seq + torch.sum(torch.cat(pooled_sources,1),1, keepdim=True)
        q_msg = torch.cat([*pooled_sources, q_seq_seq, q_seq_his, q_mix],2)

        return q_msg

mgs = MultiGrainSemanticUnit()

q_msg = mgs([q_gram_1,q_gram_2,q_seq],q_his)
        

In [10]:
class UserDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.dp_attn = DotProductAttention()
        
    def forward(self, q_msg, user_features):
        q_msg = q_msg.view(q_msg.shape[0],6,-1)

        h_reprs = []

        for q_repr_idx in range(q_msg.shape[1]):
            h_reprs.append(self.dp_attn(q_msg[:,q_repr_idx].unsqueeze(1), user_features, user_features.shape[0]))
        return torch.cat(h_reprs,2)

In [12]:
from typing import List, Dict, Optional, Callable

class MLP(torch.nn.Sequential):
    """This block implements the multi-layer perceptron (MLP) module.
    Args:
        in_channels (int): Number of channels of the input
        hidden_channels (List[int]): List of the hidden channel dimensions
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
        bias (bool): Whether to use bias in the linear layer. Default ``True``
        dropout (float): The probability for the dropout layer. Default: 0.0

    https://github.com/pytorch/vision/blob/ce257ef78b9da0430a47d387b8e6b175ebaf94ce/torchvision/ops/misc.py#L263
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = True,
        bias: bool = True,
        dropout: float = 0.0,
    ):
        params = {} if inplace is None else {"inplace": inplace}

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
        layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)

In [13]:
class UserTower(nn.Module):

    def __init__(self,hidden_dim = 20):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.msgu = MultiGrainSemanticUnit()

        self.lstm = nn.LSTM(input_size=hidden_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True)
        self.mh_attn = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)

        self.u_dp_attn = UserDotProductAttention()

        self.cls_token = nn.Embedding(1,20)

    def forward(self, query_feautures, user_features):
        """
        query_feautures is expected to be  [[q_gram_1, q_gram_2, q_seq],q_his]
        user_features = is expected to be [real, short, long]-time features
        """

        assert len(query_feautures) == 2 and len(query_feautures[0]) == 3

        q_msg = self.msgu(query_feautures[0], query_feautures[1])

        real, short, long = user_features

        real_part = self.real_time_part(q_msg, real)
        short_part = self.short_time_part(q_msg, short)
        long_part = self.long_time_part(q_msg, long)

        cls_token = self.cls_token(torch.tensor([[0]])).expand(long_part.shape[0],1, -1)

        sequence = torch.cat([cls_token, q_msg.view(-1,6, self.hidden_dim), real_part, short_part, long_part], 1)

        x, _ = self.mh_attn(sequence,sequence,sequence)

        return x[:,0]
        

    def real_time_part(self, q_msg, user_real_features):
        x, (hn, cn) = self.lstm(user_real_time)

        x, _ = self.mh_attn(x, x, x)
        x = torch.cat([torch.zeros(x.shape[0], 1, self.hidden_dim),x], dim=1)
        return torch.cat(self.u_dp_attn(q_msg, x),1)

    def short_time_part(self,q_msg, user_short_features):
        x, _ = self.mh_attn(user_short_features, user_short_features, user_short_features)
        x = torch.cat([torch.zeros(x.shape[0], 1, self.hidden_dim),x], dim=1)
        return torch.cat(self.u_dp_attn(q_msg, x),1)

    def long_time_part(self, q_msg, user_long_features):
        """
        user_long_features is expected to be [4x[clicked, bought, collected]]
        """
        h_attributes=[]
        for attribute_index in range(user_long_features.shape[1]):
            x = torch.mean(user_long_features[:,attribute_index], dim = 2)
            x = torch.cat([torch.zeros(x.shape[0], 1, self.hidden_dim),x], dim=1)
            h_attributes.append(torch.cat(self.u_dp_attn(q_msg, x),1))
        return torch.cat(h_attributes,1)

In [14]:
user_real_time = torch.rand(10,3,20)  # [bs, items, dim]
user_short_time = torch.rand(10,7,20)  # [bs, items, dim]
user_long_time = torch.rand(10,4,3,5,20)  # [bs, attributes, interactions, items, dim]

user_tower = UserTower()

In [15]:
user_emb = user_tower([query_sources, q_his], [user_real_time, user_short_time, user_long_time])

TypeError: cat() received an invalid combination of arguments - got (Tensor, int), but expected one of:
 * (tuple of Tensors tensors, int dim, *, Tensor out)
 * (tuple of Tensors tensors, name dim, *, Tensor out)


In [16]:
class ItemTower(nn.Module):

    def __init__(self, num_item_emb, num_title_emb, hidden_dim = 20):
        super().__init__()
        self.item_id_emb = nn.Embedding(num_item_emb, hidden_dim)
        self.title_emb = nn.EmbeddingBag(num_title_emb, hidden_dim)

        self.mlp = MLP(hidden_dim, [hidden_dim])

    def forward(self, item_index, title_indexes):

        item_e = self.item_id_emb(item_index.flatten())
        title_e = self.title_emb(title_indexes)

        return item_e + torch.tanh(self.mlp(title_e))

In [17]:
item_indices = torch.randint(0,100, (10,1))
titles_indices = torch.randint(0,100, (10,5))

In [19]:
item_tower = ItemTower(100,100)

In [20]:
item_emb = item_tower(item_indices,titles_indices)

In [549]:
import matplotlib.pyplot as plt

plt.imshow(user_emb @ item_emb.T)

RuntimeError: Numpy is not available