In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [83]:
# Toy query inputs. Seed first so that Restart & Run All reproduces every output.
torch.manual_seed(0)

# Shape of every toy query tensor; the last axis matches the d_model=20 used by
# the transformer layer in MultiGrainSemanticUnit. (batch, tokens, dim) layout assumed.
BATCH_SIZE, SEQ_LEN, EMB_DIM = 10, 5, 20

q_gram_1 = torch.rand(BATCH_SIZE, SEQ_LEN, EMB_DIM)
q_gram_2 = torch.rand(BATCH_SIZE, SEQ_LEN, EMB_DIM)
q_seq = torch.rand(BATCH_SIZE, SEQ_LEN, EMB_DIM)
q_his = torch.rand(BATCH_SIZE, SEQ_LEN, EMB_DIM)

# Order expected by MultiGrainSemanticUnit.forward: [q_gram_1, q_gram_2, q_seq].
query_sources = [q_gram_1, q_gram_2, q_seq]

In [255]:
class DotProductAttention(nn.Module):
    """Dot-product attention of one or more query vectors over a sequence.

    Each query is scored against every row of ``q_his`` with a dot product. The
    scores are normalised with a softmax over the sequence axis, and the result
    is the softmax-weighted sum of the rows of ``q_his``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pooled_q_seq, q_his, batch_size=None):
        """Attend ``pooled_q_seq`` over ``q_his``.

        Args:
            pooled_q_seq: query tensor of shape (batch, n_queries, dim). Callers in
                this notebook pass a single pooled query (n_queries == 1).
            q_his: key/value tensor of shape (batch, seq_len, dim).
            batch_size: optional, kept for backward compatibility. When given, it
                must equal the batch dimension of both inputs.

        Returns:
            Tensor of shape (batch, n_queries, dim): for each query, the
            attention-weighted sum of the rows of ``q_his``.

        Raises:
            ValueError: if ``batch_size`` is given and does not match the inputs.
        """
        if batch_size is not None and (
            pooled_q_seq.shape[0] != batch_size or q_his.shape[0] != batch_size
        ):
            raise ValueError(
                f"batch_size={batch_size} does not match inputs "
                f"({pooled_q_seq.shape[0]}, {q_his.shape[0]})"
            )
        scores = torch.bmm(pooled_q_seq, q_his.transpose(1, 2))  # (batch, n_queries, seq_len)
        weights = torch.softmax(scores, dim=-1)
        # The weights already sum to 1 over seq_len, so a batched matmul gives the
        # weighted average directly. The previous extra mean over seq_len divided
        # the result by seq_len a second time, and its view() trick only supported
        # n_queries == 1.
        return torch.bmm(weights, q_his)

In [256]:
class MultiGrainSemanticUnit(nn.Module):
    """Fuse several query granularities into one concatenated query representation.

    Outputs, each reduced to (batch, 1, d_model):
    * the mean-pooled version of every query source;
    * the transformer-encoded last source (q_seq), mean-pooled;
    * the attention of the pooled q_seq over the query history;
    * the sum of all of the above ("mix").
    These are concatenated along the last dimension.
    """

    def __init__(self, d_model=20, nhead=4, dim_feedforward=20):
        super().__init__()
        # Stored as attributes so they are registered submodules: their parameters
        # are trained and moved with this module. forward() previously relied on
        # globals of the same names.
        self.dp_attn = DotProductAttention()
        # batch_first=True because the inputs are (batch, seq, dim). The default
        # (seq, batch, dim) layout would make self-attention mix examples in a batch.
        self.trm = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )

    def forward(self, query_sources, q_his):
        """Compute the multi-grain query representation.

        Args:
            query_sources: non-empty sequence, expected to be
                ``[q_gram_1, q_gram_2, q_seq]``. Each tensor is assumed to be
                (batch, seq_len, d_model); the last one is treated as q_seq.
            q_his: query-history tensor, assumed (batch, his_len, d_model).

        Returns:
            Tensor of shape (batch, 1, (len(query_sources) + 3) * d_model).

        Raises:
            ValueError: if ``query_sources`` is empty.
        """
        if len(query_sources) == 0:
            raise ValueError("query_sources must contain at least q_seq")

        pooled_sources = [source.mean(dim=1, keepdim=True) for source in query_sources]
        # Use the q_seq actually passed in, not a global of the same name.
        q_seq = query_sources[-1]

        q_seq_his = self.dp_attn(pooled_sources[-1], q_his, q_his.shape[0])
        q_seq_seq = self.trm(q_seq).mean(dim=1, keepdim=True)
        q_mix = q_seq_his + q_seq_seq + torch.cat(pooled_sources, dim=1).sum(dim=1, keepdim=True)
        return torch.cat([*pooled_sources, q_seq_seq, q_seq_his, q_mix], dim=2)

# Build the multi-grain unit and apply it to the toy query inputs.
# query_sources is the [q_gram_1, q_gram_2, q_seq] list built with them.
mgs = MultiGrainSemanticUnit()
q_msg = mgs(query_sources, q_his)
        

In [233]:
# Toy user feature tensors (presumably behaviour sequences). All are 10 x _ x 20,
# with second-dimension sizes 3, 7 and 14. Drawn in that order from the global RNG.
user_real_time, user_short_time, user_long_time = (
    torch.rand(10, steps, 20) for steps in (3, 7, 14)
)


In [243]:
# MultiGrainSemanticUnit concatenates six (batch, 1, 20) vectors on dim 2,
# so this should display torch.Size([10, 1, 120]).
q_msg.size()

torch.Size([10, 6, 20])

In [264]:
class UserDotProductAttention(nn.Module):
    """Attend each query representation in ``q_msg`` over a user feature sequence.

    ``q_msg`` is split into ``n_reprs`` equal-length vectors. Each vector is used
    as a single query of ``DotProductAttention`` over ``user_features``, and the
    attended outputs are concatenated along the last dimension.
    """

    def __init__(self, n_reprs=6):
        super().__init__()
        # Default 6 = number of vectors MultiGrainSemanticUnit concatenates.
        self.n_reprs = n_reprs
        self.dp_attn = DotProductAttention()

    def forward(self, q_msg, user_features):
        """
        Args:
            q_msg: tensor whose per-example elements split evenly into ``n_reprs``
                vectors, e.g. (batch, 1, n_reprs * dim) or (batch, n_reprs, dim).
            user_features: sequence to attend over; assumed (batch, seq_len, dim).

        Returns:
            The ``n_reprs`` attended outputs concatenated on the last dimension
            (batch, 1, n_reprs * dim when DotProductAttention returns (batch, 1, dim)).
        """
        batch_size = q_msg.shape[0]
        # reshape (unlike view) also accepts non-contiguous inputs.
        q_reprs = q_msg.reshape(batch_size, self.n_reprs, -1)
        h_reprs = [
            self.dp_attn(q_reprs[:, idx].unsqueeze(1), user_features, user_features.shape[0])
            for idx in range(self.n_reprs)
        ]
        return torch.cat(h_reprs, dim=2)

In [None]:
# Short-term representations: attend every q_msg vector over the short-term
# user sequence. Reuses UserDotProductAttention instead of duplicating its loop:
# the old copy called an undefined global `dp_attn`. This version also leaves
# q_msg untouched, so the cell can be re-run safely.
short_user_attn = UserDotProductAttention()
h_short = short_user_attn(q_msg, user_short_time)
h_short

In [None]:
class UserTower(nn.Module):
    """User tower: query semantics plus real-, short- and long-time user features.

    NOTE(review): forward() previously returned nothing. It now returns the four
    intermediate representations. Fusing them into a single user vector is not
    implemented here. TODO.
    """

    def __init__(self):
        super().__init__()

        self.msgu = MultiGrainSemanticUnit()

        # nn.LSTM takes input_size / num_layers (input= / n_layers= raise TypeError).
        self.lstm = nn.LSTM(input_size=20, hidden_size=20, num_layers=1, batch_first=True)
        # NOTE(review): one attention module is shared by all three feature
        # sequences. Confirm this weight sharing is intended.
        self.mh_attn = nn.MultiheadAttention(20, 4, batch_first=True)

        self.u_dp_attn = UserDotProductAttention()

    def forward(self, query_feautures, user_features):
        """
        Args:
            query_feautures: ``[[q_gram_1, q_gram_2, q_seq], q_his]``.
            user_features: ``[real, short, long]``-time feature tensors.

        Returns:
            Tuple ``(q_msg, h_real, h_short, h_long)``.

        Raises:
            ValueError: if the inputs do not have the documented structure.
        """
        # These are Python lists, which have no .shape, so check their lengths.
        if len(query_feautures) != 2 or len(query_feautures[0]) != 3:
            raise ValueError("query_feautures must be [[q_gram_1, q_gram_2, q_seq], q_his]")
        if len(user_features) != 3:
            raise ValueError("user_features must be [real, short, long]")

        query_sources, q_his = query_feautures
        q_msg = self.msgu(query_sources, q_his)

        real, short, long = user_features
        h_real = self.real_time_part(q_msg, real)
        h_short = self.short_time_part(q_msg, short)
        h_long = self.long_time_part(q_msg, long)
        return q_msg, h_real, h_short, h_long

    def _attend_sequence(self, q_msg, seq):
        """Self-attend ``seq``, prepend a zero step, then attend to it with ``q_msg``."""
        # MultiheadAttention needs (query, key, value) and returns (output, weights).
        attended, _ = self.mh_attn(seq, seq, seq)
        # new_zeros keeps the input's device and dtype (torch.zeros defaulted to CPU float32).
        zero_step = attended.new_zeros(attended.shape[0], 1, attended.shape[-1])
        attended = torch.cat([zero_step, attended], dim=1)
        return self.u_dp_attn(q_msg, attended)

    def real_time_part(self, q_msg, user_real_features):
        """LSTM over the real-time features, then the shared attention step."""
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step outputs are used.
        lstm_out, _ = self.lstm(user_real_features)
        return self._attend_sequence(q_msg, lstm_out)

    def short_time_part(self, q_msg, user_short_features):
        """Shared attention step over the short-time features."""
        return self._attend_sequence(q_msg, user_short_features)

    def long_time_part(self, q_msg, user_long_features):
        """Shared attention step over the long-time features."""
        return self._attend_sequence(q_msg, user_long_features)

In [None]:
class ItemTower(nn.Module):

    def __init__(self):
        super().__init__()
        item_id_emb = nn.Embedding()
        title_emb = nn.EmbeddingBag()

    def forward(self, item_id, title):

        