In [12]:
import model.vanila_decoder as Decoder

In [25]:
class MultiHeadAttention(nn.Module):
    """Causal (masked) multi-head self-attention.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of width ``d_model // num_heads``, combined with
    scaled dot-product attention under an upper-triangular (causal) mask,
    merged back to ``d_model`` features and passed through an output
    projection.

    Args:
        d_model: Feature width of the input and output; must be divisible
            by ``num_heads``.
        context_length: Side length of the pre-built causal mask. Inputs
            passed to ``forward`` must not have more positions than this.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_model, context_length, num_heads, dropout, qkv_bias=False):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_head = d_model // num_heads

        self.d_model = d_model
        self.num_heads = num_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_model, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_model, bias=qkv_bias)

        self.W_O = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the (query, key) pairs where the
        # key lies in the query's future. Registered as a buffer so it follows
        # the module across devices without being a trainable parameter.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones((context_length, context_length)), diagonal=1),
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_model)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_model)``.
        """
        batch_size, seq_len, _ = x.shape

        queries = self.W_Q(x)
        keys = self.W_K(x)
        values = self.W_V(x)

        # Split the feature axis into heads:
        # (batch, seq, d_model) -> (batch, seq, heads, d_head)
        queries = queries.view(batch_size, seq_len, self.num_heads, self.d_head)
        keys = keys.view(batch_size, seq_len, self.num_heads, self.d_head)
        values = values.view(batch_size, seq_len, self.num_heads, self.d_head)

        # Move heads in front of the sequence axis: (batch, heads, seq, d_head)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product scores: (batch, heads, seq, seq)
        attention_logits = queries @ keys.transpose(2, 3) / self.d_head ** 0.5

        mask_bool = self.mask.bool()[:seq_len, :seq_len]

        # masked_fill is out-of-place, so the result must be kept; otherwise
        # the causal mask is silently dropped and positions see the future.
        attention_logits = attention_logits.masked_fill(mask_bool, -torch.inf)

        attention_weights = torch.softmax(attention_logits, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values: (batch, heads, seq, d_head)
        context_vec = attention_weights @ values

        # Undo the head split: swap heads and seq back (axes 1 and 2) before
        # flattening, so each position gets its own heads concatenated.
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        out = self.W_O(context_vec)

        return out

In [26]:
# Hyperparameters for the demo attention layer.
d_model, num_heads = 512, 8        # 512 features split over 8 heads -> 64 per head
context_length = 1024              # side length of the causal mask buffer
dropout, qkv_bias = 0.2, False     # attention-weight dropout; no bias on Q/K/V projections

attention = MultiHeadAttention(
    d_model=d_model,
    context_length=context_length,
    num_heads=num_heads,
    dropout=dropout,
    qkv_bias=qkv_bias,
)

In [27]:
# Random input batch: 12 sequences, 256 positions each, d_model features per position.
x = torch.rand(12, 256, d_model)


In [28]:
# Forward pass of the (12, 256, d_model) batch through the attention layer.
context_x = attention(x)

In [29]:
# Show the output shape; it is expected to equal the input shape (12, 256, 512).
context_x.shape

torch.Size([12, 256, 512])