In [2]:
import torch.nn as nn
import torch

- Sử dụng `CausalAttention` được implement ở phần trước.

In [7]:
class CausalAttention(nn.Module):
    def __init__(self, dim_in, dim_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.dim_out = dim_out
        self.W_query = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.W_key = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.W_value = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(   # Đăng ký một tensor không phải là tham số của mô hình
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, dim_in = x.shape

        # 3 Q,K,V có shape (batch_size, context_length, dim_out)
        keys = self.W_key(x)    
        queries = self.W_query(x)
        values = self.W_value(x)

        print("query:")
        print(queries)

        # keys.transpose(1, 2) đổi shape từ (B, T, D) thành (B, D, T)
        attention_scores = queries @ keys.transpose(1, 2)  
        # [:num_tokens, :num_tokens] -- Slicing mask để phù hợp với số token hiện tại
        attention_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf  # Dùng được từ register_buffer 
        )
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1]**0.5, dim=-1
        )
        attention_weights = self.dropout(attention_weights)

        context_vectors = attention_weights @ values

        return context_vectors

- Wrapper class triển khai `multi-head attention`.

In [4]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, dim_in, dim_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(dim_in, dim_out, context_length, dropout) for _ in range(num_heads)]
        )

    def forward(self, x):
        # dim=-1 để concat theo chiều cuối cùng, do mỗi head output shape (batch_size, num_tokens, dim_out)
        # dim_out là chiều của mỗi context vector
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [8]:
torch.manual_seed(123)
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your (x^1)
    [0.55, 0.87, 0.66], # journey (x^2)
    [0.57, 0.85, 0.64], # starts (x^3)
    [0.22, 0.58, 0.33], # with (x^4)
    [0.77, 0.25, 0.10], # one (x^5)
    [0.05, 0.80, 0.55]] # step (x^6)
)
batch = torch.stack([inputs, inputs], dim=0)  # Tạo batch size = 2
context_length = batch.shape[1]   # Num_tokens
dim_in, dim_out = batch.shape[2], 2
multi_head_attn = MultiHeadAttentionWrapper(
    dim_in=dim_in,
    dim_out=dim_out,
    context_length=context_length,
    dropout=0.0,
    num_heads=2,
)
context_vectors = multi_head_attn(batch)

print(context_vectors)
print(f"context_vectors.shape: {context_vectors.shape} -- (batch_size, context_length / num_tokens, dim_out * num_heads)")

query:
tensor([[[-0.3536,  0.3965],
         [-0.3021, -0.0289],
         [-0.3015, -0.0232],
         [-0.1353, -0.0978],
         [-0.2052,  0.0870],
         [-0.1542, -0.1499]],

        [[-0.3536,  0.3965],
         [-0.3021, -0.0289],
         [-0.3015, -0.0232],
         [-0.1353, -0.0978],
         [-0.2052,  0.0870],
         [-0.1542, -0.1499]]], grad_fn=<UnsafeViewBackward0>)
query:
tensor([[[ 0.3326,  0.5659],
         [ 0.3558,  0.5643],
         [ 0.3412,  0.5522],
         [ 0.2123,  0.2991],
         [-0.0177,  0.1780],
         [ 0.3660,  0.4382]],

        [[ 0.3326,  0.5659],
         [ 0.3558,  0.5643],
         [ 0.3412,  0.5522],
         [ 0.2123,  0.2991],
         [-0.0177,  0.1780],
         [ 0.3660,  0.4382]]], grad_fn=<UnsafeViewBackward0>)
tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5