In [2]:
import torch
from torch import nn
import torch.functional as F
import math

In [3]:
# Draw a standard-normal sample tensor of size 2 x 3 x 4, then show its
# values followed by its shape (one per line).
y = torch.randn((2, 3, 4))
print(y, y.shape, sep="\n")

tensor([[[-0.5244,  0.6244, -0.1362, -0.9387],
         [-0.6970, -1.4575, -1.3680, -0.3076],
         [-0.4164, -1.9435,  0.9441, -1.7237]],

        [[ 2.2761,  0.9115, -1.1156,  0.1597],
         [-0.9578,  0.1554, -0.5017,  0.2409],
         [-0.3129, -1.3928, -0.7515, -3.1130]]])
torch.Size([2, 3, 4])


In [13]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each passed through their own linear layer,
    split into ``n_head`` heads of width ``d_model // n_head``, combined with
    ``softmax(Q @ K^T / sqrt(d_head)) @ V`` per head, concatenated back to
    ``d_model`` and passed through a final linear layer.

    Args:
        d_model: Size of the last dimension of the inputs and of the output.
        n_head: Number of attention heads; must be positive and divide
            ``d_model`` evenly.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        # Fail early with a clear message instead of an opaque shape error
        # from ``view`` inside ``forward``.
        if n_head <= 0 or d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_head ({n_head})"
            )
        self.d_model = d_model
        self.n_head = n_head
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, x):
        """Reshape (batch, length, d_model) to (batch, n_head, length, d_head)."""
        batch, length, _ = x.shape
        n_d = self.d_model // self.n_head
        return x.view(batch, length, self.n_head, n_d).permute(0, 2, 1, 3)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Query tensor of shape (batch, len_q, d_model).
            k: Key tensor of shape (batch, len_k, d_model).
            v: Value tensor of shape (batch, len_k, d_model).
            mask: Optional tensor broadcastable to the score tensor of shape
                (batch, n_head, len_q, len_k). Positions where ``mask == 0``
                are suppressed before the softmax.

        Returns:
            Tensor of shape (batch, len_q, d_model).
        """
        batch = q.size(0)
        n_d = self.d_model // self.n_head
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # Each tensor is split using its own sequence length, so keys/values
        # may have a different length than the queries (cross-attention).
        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)

        score = q @ k.transpose(2, 3) / math.sqrt(n_d)
        if mask is not None:
            # Use the most negative finite value of the score dtype: it is
            # representable in reduced-precision dtypes (unlike -1e9 in
            # float16) and, unlike -inf, keeps fully masked rows free of NaN.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)
        context = self.softmax(score) @ v

        # (batch, n_head, len_q, d_head) -> (batch, len_q, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(batch, -1, self.d_model)
        output = self.w_combine(context)
        return output


