In [1]:
import torch
from torch import nn
import torch.functional as F
import math

In [2]:
# Seed the global RNG so the random input below is reproducible across
# "Restart Kernel -> Run All" executions.
torch.manual_seed(0)

# Dummy input batch: the shape is (128, 64, 512).
X = torch.randn(128, 64, 512)
print(X.shape)

torch.Size([128, 64, 512])


In [3]:
# Hyperparameters passed to the attention module constructor:
# feature width of the model and number of attention heads.
d_model, n_head = 512, 8

In [9]:
class multi_head_attention(nn.Module):
    """Multi-head scaled dot-product attention with a built-in causal mask.

    The inputs are linearly projected, split into ``n_head`` heads of width
    ``d_model // n_head``, attended with a lower-triangular (causal) mask,
    merged back together and passed through a final linear layer.

    Args:
        d_model: Size of the last dimension of the inputs and the output.
        n_head: Number of attention heads; must evenly divide ``d_model``.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, n_head) -> None:
        super().__init__()

        # Fail fast with a clear message instead of an opaque ``view`` error
        # the first time ``forward`` is called.
        if n_head <= 0 or d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_head ({n_head})"
            )

        self.n_head = n_head
        self.d_model = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        """Apply causal multi-head attention.

        Args:
            q: Query tensor of shape ``[batch, q_len, d_model]``.
            k: Key tensor of shape ``[batch, k_len, d_model]``.
            v: Value tensor of shape ``[batch, k_len, d_model]``.

        Returns:
            Tensor of shape ``[batch, q_len, d_model]``.
        """
        batch, q_len, _ = q.shape
        # Keys/values may have a different sequence length than the queries.
        k_len = k.shape[1]
        n_d = self.d_model // self.n_head
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # after permute: [batch, n_head, seq_len, n_d]
        q = q.view(batch, q_len, self.n_head, n_d).permute(0, 2, 1, 3)
        k = k.view(batch, k_len, self.n_head, n_d).permute(0, 2, 1, 3)
        v = v.view(batch, k_len, self.n_head, n_d).permute(0, 2, 1, 3)

        # [batch, n_head, q_len, k_len], scaled by sqrt(head width)
        score = q @ k.transpose(2, 3) / math.sqrt(n_d)

        # Lower-triangular mask: query position i may only attend to key
        # positions <= i. Built on the same device as the scores so the
        # module also works when the inputs live on an accelerator.
        mask = torch.tril(
            torch.ones(q_len, k_len, dtype=torch.bool, device=score.device)
        )
        score = score.masked_fill(~mask, float("-inf"))
        context = self.softmax(score) @ v

        # Merge the heads back: [batch, q_len, n_head * n_d == d_model]
        context = context.permute(0, 2, 1, 3).contiguous().view(batch, q_len, self.d_model)
        output = self.w_combine(context)

        return output


# Build the attention module and run it as self-attention: the same tensor
# X is passed as query, key and value.
attention = multi_head_attention(d_model=d_model, n_head=n_head)
output = attention(q=X, k=X, v=X)
print(output, output.shape)

KeyboardInterrupt: 

In [None]:
# Instantiate the module and apply self-attention to X (query, key and
# value are all the same tensor), then show the result and its shape.
attention = multi_head_attention(d_model=d_model, n_head=n_head)
output = attention(q=X, k=X, v=X)
print(output, output.shape)

tensor([[[-7.0129e-01, -5.7213e-01,  3.0061e-01,  ..., -1.7718e-01,
          -9.2418e-02,  7.8293e-02],
         [-2.7831e-01, -5.7390e-01, -1.2027e-01,  ...,  5.1195e-02,
          -1.9941e-01,  1.9650e-01],
         [-2.0199e-01, -3.9713e-01, -7.6728e-02,  ...,  9.7906e-02,
          -1.7641e-01,  3.4393e-01],
         ...,
         [ 8.0917e-03,  2.2210e-02,  3.7414e-02,  ...,  1.2549e-02,
          -7.5643e-02, -2.4918e-02],
         [ 3.9393e-02,  2.3743e-02,  5.3089e-02,  ...,  9.8831e-03,
          -4.2608e-02,  1.2276e-04],
         [ 2.3659e-02,  1.4782e-02,  4.4167e-02,  ...,  3.6256e-02,
          -4.3406e-02, -4.8235e-03]],

        [[-5.8673e-01,  1.3055e-01,  5.7295e-01,  ...,  5.2459e-01,
           1.2129e-01, -5.1063e-01],
         [-3.9818e-01,  2.4426e-01,  4.7797e-02,  ...,  2.3442e-01,
           3.0085e-01, -3.2239e-01],
         [-2.7068e-01,  2.2138e-01,  2.4531e-02,  ...,  1.8273e-01,
           1.8896e-01, -2.7297e-01],
         ...,
         [ 5.9622e-04,  2