### Single Head Attention

In [None]:
import torch 
import torch.nn as nn

In [None]:
# Hyperparameters shared by the transformer model defined in this notebook.

n_embed = 120                    # dimension of each token's embedding vector
n_heads = 8                      # number of attention heads
n_layers = 8                     # number of layers
head_size = n_embed // n_heads   # per-head dimension: 120 // 8 == 15
block_size = 256                 # context size: max number of tokens attended over
dropout = 0.2                    # dropout probability used for regularization
vocab_size = 65_536              # vocabulary size (2 ** 16)

15


We know that each token is represented by an embedding vector of dimension n_embed,

i.e. $E_i$ is a column vector of size $n_{embed} \times 1$, where $i$ runs from 1 to block_size.

Each head has a query matrix $W_q$ and a key matrix $W_k$, each of size $head\_size \times n_{embed}$.
In self-attention both are applied to the same input x:

$Q_i = W_q E_i$ and $K_i = W_k E_i$, each of size $head\_size \times 1$, for every position in the block and every sequence in the batch.

Each projection can be implemented as `nn.Linear(n_embed, head_size)`.

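How much query $i$ attends to key $j$ is given by the dot product $Q_i \cdot K_j$. This fills cell $(i, j)$ of a $T \times T$ matrix, where $T$ is block_size.

For a whole sequence this is computed as
$\text{attend} = \dfrac{Q K^\top}{\sqrt{head\_size}}$, i.e. `query @ key.transpose(-2, -1)` scaled by $1/\sqrt{head\_size}$, which keeps the softmax from saturating.

Future positions ($j > i$) are then masked to $-\infty$, and each row is normalized with softmax.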

The input x is also projected down to head_size by a value matrix $W_v$ (the outputs of all heads are concatenated later):
$V_i = W_v E_i$

The output of a single head is $\text{out} = \text{softmax}(\text{attend})\, V$.

In [None]:
class SingleHeadAttention(nn.Module):
    """One causal (masked) scaled dot-product self-attention head.

    Each token embedding is projected to a query, key and value of size
    ``head_size``. Attention weights are ``softmax(Q K^T / sqrt(head_size))``
    with a causal mask, so position ``t`` only attends to positions ``<= t``.
    The output is the attention-weighted sum of the values.

    Args:
        n_embed: Dimension of the input token embeddings.
        head_size: Dimension of the query/key/value projections and of the
            output.
        dropout: Dropout probability applied to the attention weights.
            Defaults to 0.0 (no dropout), which leaves the output unchanged.
    """

    def __init__(self, n_embed, head_size, dropout=0.0):
        super().__init__()

        self.n_embed = n_embed
        self.head_size = head_size
        self.key = nn.Linear(n_embed, head_size)
        self.query = nn.Linear(n_embed, head_size)
        self.value = nn.Linear(n_embed, head_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, n_embed).

        Returns:
            Tensor of shape (batch_size, seq_len, head_size).
        """
        key = self.key(x)      # (B, T, head_size)
        query = self.query(x)  # (B, T, head_size)
        value = self.value(x)  # (B, T, head_size)

        # (B, T, hs) @ (B, hs, T) -> (B, T, T): affinity of every query with
        # every key. Scaling by 1/sqrt(head_size) keeps the softmax from
        # saturating as head_size grows.
        attend = query @ key.transpose(-2, -1) / (self.head_size ** 0.5)

        # Build the mask on the input's device; a CPU mask fails for GPU
        # inputs. True on/below the diagonal = positions a token may see.
        seq_len = attend.shape[-1]
        causal = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device)
        )
        attend = attend.masked_fill(~causal, float('-inf'))  # hide future tokens
        attend = torch.softmax(attend, dim=-1)
        attend = self.dropout(attend)

        return attend @ value  # (B, T, head_size)


