In [1]:
import torch 
import torch.nn as nn 

In [44]:
class MergeEmbedWithConv(nn.Module):
    """Merge a per-sample embedding into a per-frame embedding via a learned gate.

    The two embeddings are concatenated along the channel axis and passed
    through a kernel-size-1 ``Conv1d`` followed by a sigmoid. The resulting
    mask scales the per-frame embedding, and the per-sample embedding is
    added back on top: ``out = mask * emb2 + emb1``.

    Args:
        emb_dim: channel size of both input embeddings (and of the output).
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.emb_dim = emb_dim
        self.mask_conv = nn.Sequential(
            nn.Conv1d(emb_dim * 2, emb_dim, 1),
            # nn.BatchNorm1d(emb_dim),
            nn.Sigmoid(),
        )

    def forward(self, emb1, emb2):
        """Gate ``emb2`` with a mask computed from both embeddings, then add ``emb1``.

        input shape:
            emb1: (batch, emb_dim)
            emb2: (batch, 1, emb_dim, T)
        output shape:
            (batch, T, emb_dim)
        """
        # (batch, 1, emb_dim, T) -> (batch, T, emb_dim)
        frames = emb2.squeeze(1).permute(0, 2, 1)
        num_frames = frames.size(1)

        # Broadcast emb1 over the time axis. ``expand`` returns a view, so no
        # (batch, T, emb_dim) copy is materialised the way ``repeat`` would.
        sample = emb1.unsqueeze(1).expand(-1, num_frames, -1)

        # Channel-wise concatenation: (batch, T, 2 * emb_dim).
        merged = torch.cat((sample, frames), dim=-1)

        # Conv1d expects (batch, channels, T); permute in and back out.
        mask = self.mask_conv(merged.permute(0, 2, 1)).permute(0, 2, 1)

        return mask * frames + sample

In [45]:
# Seed the RNG so the random weights and inputs, and therefore the printed
# range, are the same on every Restart & Run All.
torch.manual_seed(0)
# Use the GPU when one is available; otherwise fall back to CPU so the cell
# still runs on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_dim = 60
model = MergeEmbedWithConv(model_dim)
# Batch of 2, 50 frames; both inputs are standard-normal samples scaled by 100.
emb1 = 100*torch.randn(2, model_dim)
emb2 = 100*torch.randn(2, 1, model_dim, 50)
model = model.to(device)
emb1 = emb1.to(device)
emb2 = emb2.to(device)
print("emb1 range: ", emb1.min(), emb1.max())
model.eval()

emb1 range:  tensor(-225.9766, device='cuda:0') tensor(221.0042, device='cuda:0')


MergeEmbedWithConv(
  (mask_conv): Sequential(
    (0): Conv1d(120, 60, kernel_size=(1,), stride=(1,))
    (1): Sigmoid()
  )
)

In [46]:
# Forward pass over the full sequence.
out1 = model(emb1, emb2)
# The output is (batch, T, emb_dim); check it against the inputs rather than
# relying on a hand-written shape comment that can go stale.
assert out1.shape == (emb2.shape[0], emb2.shape[-1], model_dim)
print(out1.shape)

torch.Size([2, 50, 60])


In [47]:
# Forward pass on only the final time step of emb2; the `-1:` slice keeps the
# trailing time axis (length 1) instead of dropping it.
last_step = emb2[..., -1:]
out2 = model(emb1, last_step)
print(out2.shape)

torch.Size([2, 1, 60])


In [48]:
# Element-wise difference between the last frame of the full-sequence output
# (out1) and the single frame of the last-step-only output (out2).
last_of_full = out1[:, -1, :]
last_only = out2[:, 0, :]
print(last_of_full - last_only)

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.3499e-02,
          0.0000e+00,  7.8201e-05,  0.0000e+00,  0.0000e+00, -6.7062e-02,
          0.0000e+00,  1.4481e-02,  0.0000e+00,  3.4332e-05,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  4.6349e-04,
          2.4414e-04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  3.1311e-02,
         -4.5776e-05,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  1.9836e-04,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  1.3159e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.9998e-01,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -1.5259e-04,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  7