In [5]:
# version 4: self-attention!
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time (sequence length), channels
x = torch.randn(B, T, C)

# A single self-attention head; multi-head attention comes later.
head_size = 16
# The three projections are created in this order (query, key, value) on purpose:
# with the fixed seed above, the order determines each layer's initial weights.
# query: what each token is looking for (like a YouTube search string)
query = nn.Linear(C, head_size, bias=False)
# key: what each token offers, matched against queries (like a video's title/description)
key = nn.Linear(C, head_size, bias=False)
# value: the information a token hands over when attended to (like the retrieved videos)
value = nn.Linear(C, head_size, bias=False)

q = query(x)  # (B, T, head_size)
k = key(x)    # (B, T, head_size)
v = value(x)  # (B, T, head_size)

# Raw affinities for every (query position, key position) pair:
# (B, T, hs) @ (B, hs, T) -> (B, T, T). Not yet scaled by head_size**-0.5.
wei = q @ k.transpose(-2, -1)

# Causal mask: position t may only aggregate from positions <= t.
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)

out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
out.shape

torch.Size([4, 8, 16])

In [8]:
# Attention weights are data-dependent: every batch element gets its own (T, T) pattern,
# and each row weights the components of its history differently.
wei.select(0, 0)  # equivalent to wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5877, 0.4123, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4457, 0.2810, 0.2733, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2220, 0.7496, 0.0175, 0.0109, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0379, 0.0124, 0.0412, 0.0630, 0.8454, 0.0000, 0.0000, 0.0000],
        [0.5497, 0.2187, 0.0185, 0.0239, 0.1831, 0.0062, 0.0000, 0.0000],
        [0.2576, 0.0830, 0.0946, 0.0241, 0.1273, 0.3627, 0.0507, 0.0000],
        [0.0499, 0.1052, 0.0302, 0.0281, 0.1980, 0.2657, 0.1755, 0.1474]],
       grad_fn=<SelectBackward0>)

In [9]:
# Second batch element: a different, input-dependent pattern (equivalent to wei[1])
wei.select(0, 1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4289, 0.5711, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5413, 0.1423, 0.3165, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0635, 0.8138, 0.0557, 0.0669, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4958, 0.0758, 0.2224, 0.0156, 0.1905, 0.0000, 0.0000, 0.0000],
        [0.3957, 0.1127, 0.3724, 0.0024, 0.1128, 0.0040, 0.0000, 0.0000],
        [0.0229, 0.5252, 0.0084, 0.0047, 0.2768, 0.0983, 0.0637, 0.0000],
        [0.0021, 0.0327, 0.0042, 0.0821, 0.0244, 0.8253, 0.0154, 0.0139]],
       grad_fn=<SelectBackward0>)

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across the batch dimension is processed completely independently; examples never "talk" to each other.
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
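- "Scaled" attention additionally divides `wei` by sqrt(head_size) (i.e. multiplies it by 1/sqrt(head_size)). Each entry of `q @ k^T` is a sum of head_size products of unit-variance terms, so its variance is about head_size; the scaling makes it so that when the input Q,K are unit variance, wei is unit variance too, and the softmax stays diffuse instead of saturating. Illustration below.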

In [10]:
# Unit-variance random keys and queries (k drawn first, then q) to show the effect of scaling.
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
scale = head_size ** -0.5  # 1/sqrt(head_size)
wei = (q @ k.transpose(-2, -1)) * scale

In [11]:
torch.var(k)

tensor(1.0449)

In [12]:
torch.var(q)

tensor(1.0700)

In [13]:
torch.var(wei)  # ~1 thanks to the 1/sqrt(head_size) factor

tensor(1.0918)

In [29]:
# Same experiment, but WITHOUT the 1/sqrt(head_size) factor (k drawn first, then q).
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei = torch.matmul(q, k.transpose(-2, -1))

In [30]:
torch.var(k)

tensor(1.0632)

In [31]:
torch.var(q)

tensor(0.9891)

In [32]:
torch.var(wei)  # ~head_size (16): dividing wei by sqrt(head_size) = 4 brings the variance back to ~1

tensor(15.6088)

In [14]:
torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]).softmax(dim=-1)  # small logits -> diffuse distribution

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [16]:
torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]).mul(8).softmax(dim=-1)  # large logits -> peaky, tends to one-hot (cf. initialization schemes for NAS)

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])

In [None]:
# Why is this a bad starting point?
# A peaky softmax makes each token attend almost exclusively to a single other token.
# We do not want the attention mechanism to be biased towards particular tokens in the
# history, especially before the model has even started training.

In [44]:
# Multi-head self-attention (MSA)
# Intuition: several heads can look for different (possibly orthogonal or complementary)
# things in the context, in parallel.
n_embd = 32      # input embedding dimension (C)
block_size = 8   # longest sequence the causal mask supports (T <= block_size)
head_size = 8    # output dimension of each head
class Head(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: output dimension of the key/query/value projections.
        embed_dim: input feature dimension; defaults to the notebook-level ``n_embd``.
        context_len: size of the causal mask, i.e. the longest supported sequence;
            defaults to the notebook-level ``block_size``.

    forward(x): x of shape (B, T, embed_dim) with T <= context_len
    returns a tensor of shape (B, T, head_size).
    """

    def __init__(self, head_size, embed_dim=None, context_len=None):
        super().__init__()
        # Fall back to the notebook globals so existing `Head(head_size)` calls keep working.
        embed_dim = n_embd if embed_dim is None else embed_dim
        context_len = block_size if context_len is None else context_len
        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        # Lower-triangular causal mask, stored as a buffer rather than a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(context_len, context_len)))

    def forward(self, x):
        _, T, _ = x.shape
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        # compute attention scores ("affinities"), scaled by 1/sqrt(head_size) -- not
        # 1/sqrt(embed_dim) -- so unit-variance q/k give unit-variance scores and the
        # softmax stays diffuse at initialization.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B, T, hs)
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out
    
class MultiHeadAttention(nn.Module):
    """Several self-attention heads run in parallel, concatenated, then projected.

    Args:
        num_heads: number of independent ``Head`` modules.
        head_size: output dimension of each head.

    forward(x): (B, T, n_embd) -> (B, T, n_embd)
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated heads carry num_heads * head_size channels, which need not
        # equal n_embd; project them back to n_embd.
        self.proj = nn.Linear(num_heads * head_size, n_embd)

    def forward(self, x):
        # Stack the head outputs along the channel dim: (B, T, num_heads * head_size)
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(out)  # (B, T, n_embd)

In [46]:
# Run MSA on a dummy batch of 32 sequences; `x` is replaced by the MSA output,
# which the feed-forward demo further down consumes.
num_heads = 4  # 4 heads * head_size 8 = 32 = n_embd
x = torch.randn(32, block_size, n_embd)
msa = MultiHeadAttention(num_heads, head_size)
x = msa(x)
assert x.shape == (32, block_size, n_embd)
print(x.shape)  # the MSA result -- not the stale `out` left over from an earlier cell

torch.Size([32, 8, 32])


In [51]:
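# Placed after the MSA layer: attention lets tokens communicate, then this block
# processes ("computes on") each token's gathered information independently.
# Projects up to a higher dimension and back down; the expansion ratio (4 here)
# is often a hyperparameter.
# Input B,T,C -> B,T,4*C -> B,T,C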
import torch.nn as nn
class FeedFoward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Expands the last (feature) dimension from ``n_embd`` to ``expansion * n_embd`` and
    projects it back, so (..., n_embd) -> (..., n_embd). ``nn.Linear`` acts on the last
    dimension only, so every position is transformed independently.

    Args:
        n_embd: input/output feature dimension.
        expansion: hidden-layer width multiplier (default 4).
    """

    def __init__(self, n_embd, expansion=4):
        super().__init__()
        hidden = expansion * n_embd
        self.net = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
        )

    def forward(self, x):
        return self.net(x)


# Correctly spelled alias; the original name is kept for existing callers.
FeedForward = FeedFoward

In [53]:
# Feed-forward block applied to `x` (the MSA output assigned in the MSA demo cell).
ffn = FeedFoward(n_embd)
out = ffn(x)  # (B, T, n_embd) -> (B, T, n_embd)
print(out.size())

torch.Size([32, 8, 32])
