In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [19]:
torch.manual_seed(42)

B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn(B,T,C)


head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B,T, 16)
q = query(x) # (B,T,16)

wei = q @ k.transpose(-2, -1) # (B,T,16) @ (B, 16, T) --> (B,T,T)


tril = torch.tril(torch.ones(T,T))
# wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v


out.shape


torch.Size([4, 8, 16])

In [18]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1905, 0.8095, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3742, 0.0568, 0.5690, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1288, 0.3380, 0.1376, 0.3956, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4311, 0.0841, 0.0582, 0.3049, 0.1217, 0.0000, 0.0000, 0.0000],
        [0.0537, 0.3205, 0.0694, 0.2404, 0.2568, 0.0592, 0.0000, 0.0000],
        [0.3396, 0.0149, 0.5165, 0.0180, 0.0658, 0.0080, 0.0373, 0.0000],
        [0.0165, 0.0375, 0.0144, 0.1120, 0.0332, 0.4069, 0.3136, 0.0660]],
       grad_fn=<SelectBackward0>)

In [20]:
import torch

batch_size = 2
seq_length = 4
d_model = 512
num_heads = 8
d_k = d_model // num_heads

# Random input tensor representing token embeddings
X = torch.randn(batch_size, seq_length, d_model)

print("Input tensor X shape:", X.shape)


Input tensor X shape: torch.Size([2, 4, 512])


In [22]:
query_layer = torch.nn.Linear(d_model, d_model)
key_layer = torch.nn.Linear(d_model, d_model)
value_layer = torch.nn.Linear(d_model, d_model)

# Project input to Q, K, V
Q = query_layer(X)  # Shape: (batch_size, seq_length, d_model)
K = key_layer(X)    # Shape: (batch_size, seq_length, d_model)
V = value_layer(X)  # Shape: (batch_size, seq_length, d_model)

print("Projected Q shape:", Q.shape)
print("Projected K shape:", K.shape)
print("Projected V shape:", V.shape)


Projected Q shape: torch.Size([2, 4, 512])
Projected K shape: torch.Size([2, 4, 512])
Projected V shape: torch.Size([2, 4, 512])


In [23]:
Q = Q.view(batch_size, seq_length, num_heads, d_k).transpose(1, 2)  # Shape: (batch_size, num_heads, seq_length, d_k)
K = K.view(batch_size, seq_length, num_heads, d_k).transpose(1, 2)  # Shape: (batch_size, num_heads, seq_length, d_k)
V = V.view(batch_size, seq_length, num_heads, d_k).transpose(1, 2)  # Shape: (batch_size, num_heads, seq_length, d_k)

print("Q split into heads shape:", Q.shape)
print("K split into heads shape:", K.shape)
print("V split into heads shape:", V.shape)

Q split into heads shape: torch.Size([2, 8, 4, 64])
K split into heads shape: torch.Size([2, 8, 4, 64])
V split into heads shape: torch.Size([2, 8, 4, 64])


In [28]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

torch.manual_seed(42)

B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn(B,T,C)



class MultiHeadAttentioBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float) -> None:
        super().__init__()
        # Embedding vector size. Each token in the input sequence is mapped 
        # to an embedding vector of the size
        self.d_model = d_model 
        self.num_heads = num_heads
        assert d_model % num_heads == 0, "d_model is not divisible by h"

        self.d_k = d_model // num_heads # dimension of vector seen by each head

        self.w_q = nn.Linear(d_model, d_model, bias=False) # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False) # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False) # Wv

        self.w_o = nn.Linear(d_model, d_model, bias=False) # Wv


        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]

        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)


        if mask is not None:
            # small value to indicate not to tend to those positions
            attention_scores.masked_fil_(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)            

        return (attention_scores @ value)

    def forward(self, q, k, v, mask):
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)

        # (batch. seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.num_heads, self.d_k).transpose(1,2)
        key = key.view(key.shape[0], key.shape[1], self.num_heads, self.d_k).transpose(1,2)
        value = value.view(value.shape[0], value.shape[1], self.num_heads, self.d_k).transpose(1,2)

        x = MultiHeadAttentioBlock.attention(query, key, value, mask, self.dropout)


        # Combine the heads together
        
        # (batch, h, seq_len, d_k) --> () --> (batch, seq_len,d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        return self.w_o(x)

In [30]:
# Test the MultiHeadAttentionBlock
d_model = 32
num_heads = 4
dropout = 0.1

# Create a MultiHeadAttentionBlock instance
mha_block = MultiHeadAttentionBlock(d_model, num_heads, dropout)

# Random input tensor
q = k = v = x  # Using the same tensor for query, key, and value

# Forward pass
output = mha_block(q, k, v)

print("Output shape:", output.shape)

NameError: name 'MultiHeadAttentionBlock' is not defined