In [1]:
import torch
import torch.nn as nn

In [2]:
# Six tokens of the sentence "Your journey starts with one step", each
# represented by a 3-dimensional vector (rows x^1 ... x^6).
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x^1: "Your"
        [0.55, 0.87, 0.66],  # x^2: "journey"
        [0.57, 0.85, 0.64],  # x^3: "starts"
        [0.22, 0.58, 0.33],  # x^4: "with"
        [0.77, 0.25, 0.10],  # x^5: "one"
        [0.05, 0.80, 0.55],  # x^6: "step"
    ]
)

In [3]:
# Two copies of the same sequence stacked along a new leading batch axis -> (2, 6, 3).
batch = torch.stack([inputs, inputs])
print(batch.shape)

torch.Size([2, 6, 3])


In [4]:
class MultiHeadAttention(nn.Module):
    """Causal (masked) multi-head self-attention.

    The input is projected to queries, keys and values of width ``d_out``.
    These are split into ``num_heads`` heads of width ``d_out // num_heads``.
    Each head computes scaled dot-product attention under a causal mask, so a
    token attends only to itself and earlier tokens. The per-head results are
    concatenated back to width ``d_out`` and passed through a final linear
    projection.

    Args:
        d_in: Feature size of each input token vector.
        d_out: Feature size of the output; must be divisible by ``num_heads``.
        context_length: Maximum number of tokens supported; sizes the causal mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections include a bias term.
    """

    def __init__(self, d_in: int, d_out: int, context_length: int, dropout: float, num_heads: int, qkv_bias: bool = False):
        super().__init__()
        # Every head must receive the same whole number of output features.
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.head_dim = d_out // num_heads  # feature width of a single head
        self.dropout = nn.Dropout(dropout)
        self.context_length = context_length
        # Submodules are created in this fixed order so that, under a given
        # torch.manual_seed, weight initialisation stays reproducible.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Learned d_out -> d_out projection that mixes the concatenated head outputs.
        self.out_proj = nn.Linear(d_out, d_out)
        self.num_heads = num_heads

        # Ones strictly above the diagonal mark "future" positions to be hidden.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal multi-head self-attention.

        Args:
            x: Tensor of shape (batch, num_tokens, d_in).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).

        Raises:
            ValueError: If ``num_tokens`` exceeds ``context_length``. The
                causal mask is only that large.
        """
        b, num_tokens, _ = x.shape
        if num_tokens > self.context_length:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {self.context_length}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the last dimension into heads: (b, T, d_out) -> (b, T, H, head_dim),
        # then move heads forward: -> (b, H, T, head_dim).
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # Per-head similarity between every query and key: (b, H, T, T).
        attn_scores = queries @ keys.transpose(2, 3)
        # -inf on future positions becomes 0 after softmax. The diagonal is never
        # masked, so each row keeps at least one finite score.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before softmax to keep scores in a stable range.
        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, H, T, head_dim) -> (b, T, H, head_dim) -> (b, T, d_out).
        # .contiguous() is required because .view() needs contiguous memory
        # after the transpose.
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

In [5]:
# The same six 3-dimensional token vectors as in the earlier cell, redefined
# here so this cell can be run on its own.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)

# Batch of two identical sequences, shape (2, 6, 3).
batch = torch.stack([inputs] * 2, dim=0)

In [6]:
# Demo configuration:
#   d_in           - width of each input token vector (3, matching `inputs`)
#   d_out          - width of the attention output
#   context_length - longest sequence the causal mask supports
#   dropout        - dropout probability on the attention weights
#   num_heads      - number of heads; head width is d_out // num_heads (here 1)
d_in, d_out, context_length, dropout, num_heads = 3, 4, 6, 0.1, 4

In [7]:
# Seed before construction so the Linear weight initialisation is reproducible.
torch.manual_seed(123)
mha = MultiHeadAttention(d_in, d_out, context_length, dropout, num_heads)
# NOTE: the module is in training mode here, so dropout is active. This is why
# the two identical sequences in `batch` can produce different outputs.
context_vecs = mha(batch)

for label, shape in (("Input shape:", batch.shape), ("Output shape:", context_vecs.shape)):
    print(label, shape)
print("\nOutput values:")
print(context_vecs)

Input shape: torch.Size([2, 6, 3])
Output shape: torch.Size([2, 6, 4])

Output values:
tensor([[[ 0.1184,  0.3252, -0.0870, -0.5899],
         [ 0.0076,  0.3391, -0.0776, -0.4197],
         [-0.0287,  0.3427, -0.0744, -0.3635],
         [-0.0438,  0.3239, -0.0731, -0.3321],
         [-0.0318,  0.2762, -0.0816, -0.3648],
         [-0.0538,  0.3060, -0.0716, -0.3083]],

        [[ 0.1184,  0.3252, -0.0870, -0.5899],
         [-0.1006,  0.3153, -0.0877, -0.3057],
         [-0.0287,  0.3427, -0.0744, -0.3635],
         [-0.0388,  0.2975, -0.0805, -0.3573],
         [ 0.0263,  0.3174, -0.0555, -0.3554],
         [-0.0468,  0.2723, -0.0740, -0.3175]]], grad_fn=<ViewBackward0>)
