In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [34]:
class MutiHeadSelfAttention(nn.Module):
    def __init__(self, n_dim: int=768, n_head: int=8, attn_dropout: float=0.1):   
        super().__init__()
        self.n_dim = n_dim
        self.n_head = n_head
        self.c_attn = nn.Linear(n_dim, 3 * n_dim)
        self.c_dropout = nn.Dropout(attn_dropout)
        
        self.proj = nn.Linear(n_dim, n_dim)
        self.resid_dropout = nn.Dropout(attn_dropout)
    
    def forward(self, x, mask=False):
        B, T, C = x.shape
        
        q, k, v = self.c_attn(x).split(self.n_dim, 2)
        q = q.view(B, T, self.n_head, self.n_dim // self.n_head).transpose(1, 2) #(B, nh, T, nd)
        k = k.view(B, T, self.n_head, self.n_dim // self.n_head).transpose(1, 2) #(B, nh, T, nd)
        v = v.view(B, T, self.n_head, self.n_dim // self.n_head).transpose(1, 2) #(B, nh, T, nd)
        
        attn = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(torch.tensor(k.shape[-1])))
        if mask:
            masked = torch.tril(torch.ones(T, T, dtype=bool))
            attn = attn.masked_fill(masked == 0, float('-inf'))
            print(attn)
        attn = F.softmax(attn, dim=-1)
        attn = self.c_dropout(attn)
        out = attn @ v #(B, nh, T, T) x (B, nh, T, nd) = (B, nh, T, nh)
        
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.resid_dropout(self.proj(out))
        return out
        

In [35]:
x = torch.randn((1, 3, 5))
attention = MutiHeadSelfAttention(n_head=1, n_dim=5)
out = attention(x, mask=True)
print(out.shape)

tensor([[[[0.0170,   -inf,   -inf],
          [0.2533, 0.1607,   -inf],
          [0.2029, 0.2582, 1.0119]]]], grad_fn=<MaskedFillBackward0>)
torch.Size([1, 3, 5])
