In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2 # Batch, Time, Channels : Time component being the given x tokens predict next token which moves over time
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [3]:
# version 1 : nested for loop select previous and mean

In [4]:
# we want x[b, t] = mean_{i<=t} x[b, i]
xbow = torch.zeros(B, T, C) # bow = bag of words : term used when averaging up words
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t, C)
        xbow[b, t] = torch.mean(xprev, 0)

In [5]:
xbow.shape

torch.Size([4, 8, 2])

In [6]:
# version 2 : lower triangular normalised over rows matrix multiply

# we can very very efficient about this using matrix multiplication
# lets look at a toy example

In [7]:
torch.manual_seed(42)
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=1, keepdim=True)
wei

xbow2 = wei @ x # (T, T) @ (B, T, C) => (B, T, T) @ (B, T, C) => batched matrix mult => (B, T, C)

In [8]:
torch.allclose(xbow, xbow2)

True

In [9]:
# version 3 : using softmax

In [10]:
tril = torch.tril(torch.ones(T, T))
wei  = torch.zeros(T, T) # torch.ones(T, T) : doesn't matter as long as all elments equal since they'll get softmaxed
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)

xbow3 = wei @ x

In [11]:
torch.allclose(xbow3, xbow)

True

In [12]:
# Long story short
# You can do weighted aggeration of past elements
# by using matrix multiplication
# of a lower triangular fashion
# the elements of which are telling you how much of each elements fused/contributes to an individual position

In [13]:
# version 4 : self attention

In [14]:
# going to implement self attention for a small individual head

In [15]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # Using 32 channels 4 x 8 tokens, each token 32 dimensional
x = torch.randn(B, T, C)

In [16]:
tril = torch.tril(torch.ones(T, T))
wei  = torch.zeros(T, T) # torch.ones(T, T) : doesn't matter as long as all elments equal since they'll get softmaxed
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)

out = wei @ x

In [17]:
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [18]:
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [19]:
# in version 3 when we initialise the affinities to be zeroes,
# then we see that wei gets uniform numbers, which softmax processes into a average

# we don't want this to be uniform tokens
# different tokens will find different other tokes more or less interesting, not uniformly interesting
# we want that to be data dependent

# this is the problem that self attention solves, information flow from the past to me in a data dependent way

In [20]:
# HEART OF THE THING

# Every single token at each position will emit 2 vectors (emit ?? : say bring along with them I guess)
# a Query vector and a Key vector (every token at every position)
# Query vector : "What am I looking for?"
# Key vector:    "What do I contain?"

# The way we get affinities between these tokens in a sequence
# we do a dot product between the Keys and the Queries
# for token N, Query(N) will dot product with all the Key(s) of (0..N-1)
# my query will dot product with all the keys of the other tokens (in the past)
# and that dot product now becomes the weight (as in the wei)
# instead of starting out normalised, the weights are deduced form Query Key dot product

# If the key and query are aligned, they'll interact with high amount
# and I will get to learn more about that specific token, as opposed to any other token in the sequence
# that's it really

In [21]:
# Implement single 'head' of self attention

In [22]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # Using 32 channels 4 x 8 tokens, each token 32 dimensional
x = torch.randn(B, T, C)

In [23]:
# single head of self attention
# Remember C is channels, dimension of tokens

head_size = 16 # attention hyper parm
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

In [24]:
# take note that the key and query production forward is in parallel

k = key(x)   # (B, T, head_size)
q = query(x) # (B, T, head_size)

# no communication has happened yet

In [25]:
k.shape, q.shape

(torch.Size([4, 8, 16]), torch.Size([4, 8, 16]))

In [26]:
# all queries now get dot producted with all the keys to produce the weights matrix

# but we need to align the shapes first

In [27]:
# we want to transpose the last two dimensions, i.e. the last two dimensions

In [28]:
k.transpose(-2, -1).shape

torch.Size([4, 16, 8])

In [29]:
# multid dim matmul happen on last two dims, need to align the keys and queries in those dims, keep the batch outside

In [30]:
wei = q @ k.transpose(-2, -1) # (B, T, hs) @ (B, hs, T) ----> (B, T, T)

In [31]:
wei[0]

tensor([[-1.7629, -1.3011,  0.5652,  2.1616, -1.0674,  1.9632,  1.0765, -0.4530],
        [-3.3334, -1.6556,  0.1040,  3.3782, -2.1825,  1.0415, -0.0557,  0.2927],
        [-1.0226, -1.2606,  0.0762, -0.3813, -0.9843, -1.4303,  0.0749, -0.9547],
        [ 0.7836, -0.8014, -0.3368, -0.8496, -0.5602, -1.1701, -1.2927, -1.0260],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,  0.8638,  0.3719,  0.9258],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,  1.4187,  1.2196],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,  0.8048],
        [-1.8044, -0.4126, -0.8306,  0.5899, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)

In [32]:
# for every row of b(atch), we will have a T-square matrix giving us the affinities and these are the weights

In [33]:
tril = torch.tril(torch.ones(T, T))
# wei  = torch.zeros(T, T) # torch.ones(T, T) # replaced by actual weights above
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)

out = wei @ x

In [34]:
wei[0]

tensor([[0.0248, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0052, 0.0091, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0521, 0.0135, 0.2482, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3171, 0.0214, 0.1642, 0.1188, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0412, 0.0487, 0.1046, 0.0742, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1060, 0.5347, 0.2059, 0.1030, 0.7402, 0.0192, 0.0000, 0.0000],
        [0.4298, 0.3409, 0.1769, 0.2027, 0.0480, 0.8472, 0.2329, 0.0000],
        [0.0238, 0.0316, 0.1002, 0.5013, 0.0117, 0.1336, 0.7671, 1.0000]],
       grad_fn=<SelectBackward0>)

In [35]:
# before wei was a constant value for each batch row, but now they're different

In [36]:
# say the 8th token, it know what content it has
# and it knows what the position it has (token emb + pos emb)

In [37]:
# so the token creates a query, "i'm looking for this kind of stuff"
# other token before it get to influence the interaction by their keys
# high affinity when the dot product operation gets a high return

In [38]:
# Ah aw, that's not it, we emit one more value per token, that we call 'value', similar to key and queries

In [39]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # Using 32 channels 4 x 8 tokens, each token 32 dimensional
x = torch.randn(B, T, C)

In [40]:
# single head of self attention
# Remember C is channels, dimension of tokens

head_size = 16 # attention hyper parm
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

value = nn.Linear(C, head_size, bias=False)

In [41]:
# and similarly we forward the value too, also in parallel
k = key(x)   # (B, T, head_size)
q = query(x) # (B, T, head_size)

v = value(x) # (B, T, head_size)

In [42]:
# we get the weights
wei = q @ k.transpose(-2, -1) # (B, T, hs) @ (B, hs, T) ----> (B, T, T)

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)

In [43]:
wei.shape, v.shape, x.shape

(torch.Size([4, 8, 8]), torch.Size([4, 8, 16]), torch.Size([4, 8, 32]))

In [44]:
# but finally we don't just get the out directly from weights and x, instead we use the value

In [45]:
out = wei @ v
out.shape

torch.Size([4, 8, 16])

In [46]:
# note that the output is head_size dimensional

In [47]:
# think of x as private information to the particular token, and v as the value communicated to other tokens

In [48]:
# query : here's what i'm interested in
# key   : here's what I have
# value : if you find we interesting(given the above affinities), here's what i'll communicate to you (you will have access to)

In [49]:
# query and keys act like filters which will define how much of value will be communicated after aggregation

In [50]:
# Notes

# Attention is a communication mechanism
# number of nodes in DAG
# every node has some vector of information (value)
# gets to aggregate vectors of information via weighted sum from all of the nodes that point to it
# aggregation done in a data dependent manner

# No notion of space
# spatial information has to be encoded, the nodes themselves don't know
# unlike convolutions

# Each example across batch dimension is processed independently and never talk to each other

# Not being able to attend to future tokens does not necessarily have to be the case
# In an encoder just remove the tril based masking, this model is effectively like a decoder hence the constraint
# Attention mechanism does not care

# Cross attention vs. Self attention
# this one is self attention because the keys queries and values all come from the same source (from input x)
# for eg. in Encoder Decoder Transformer
# queries ar produced from x
# keys and values can come from a whole different source(from the encoder block)
# the source has the context we want to condition our attention on
# When this happens its called cross attention
# separate source of information we'd like to pool attention from into our nodes
# self attention if we have nodes that'd like to look at each other and talk to each other

# From the equation of self attention we're missing one more thing, dividing by sq.root of the head size
# called as the scale attention
# an important normalisation, since we've ad hoc applied head size, we need to turn that back down
# in a sense we've let attention aggregate information over head_size dimensions
# if q and k are gaussian, they matmul, they're outputs will be in the order of head_size
# we need that to be guassian, so we normalise it

In [51]:
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
_w = q @ k.transpose(-2, -1)

_w.var()

tensor(17.4690)

In [52]:
# but if we normalise it

In [53]:
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
_w = q @ k.transpose(-2, -1) * (head_size ** -0.5)

_w.var()

tensor(0.9957)

In [54]:
# it's back to gaussian

In [55]:
# why is it important ?
# wei is fed into softmax, and it's important that wei be fairly diffused at init time
# if the values in wei are extreme, then the output of softmax can converge to be one hot vectors, not very useful

In [56]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5 ]), dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [57]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5 ])*9, dim=-1)

tensor([0.0228, 0.0015, 0.1382, 0.0015, 0.8359])

In [58]:
# softmax sharpens towards the max
# therefore we don't want these values to be too extreme, esp. at init, otherwise softmax will be way to picky
# every node will aggregate information only from other single node

In [59]:
# All the above mechanism goes in a single 'head' of attention

In [60]:
# single head of self attention
class Head(nn.Module):
    """ one head of self attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # in Pytorch convention a variable that's not a parameter of the model is called a buffer
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        
    def forward(self, x):
        B,T,C = x.shape
        # emit keys and queries for x
        k = self.key(x)  # (B, T, hs)
        q = self.query(x) # (B, T, hs)
        # compute attention
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B, T, hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

In [61]:
# Multi head attention
#
# applying multiple attention heads in parallel and the concatenating the results
# that's it really : the head sizes are matched so that the concatenation gives out the same dimensions

In [62]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    
    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

In [63]:
# Feedforward mechanism

# is just a MLP

In [64]:
class FeedForward(nn.Module):
    " simple linear layer followed by non linearity "
    
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, n_embed),
            nn.ReLU(),
        )
        
    def forward(self, x):
        return self.net(x)

In [65]:
# the tokens get an opportunity to use the multi headed attention, but then go right away into calculating logits
# looked at each other but didn't have enough time to think on what they found

In [66]:
# first we (multi) self attend - this is the communication
# then we feedforward - per token in parallel - independent of other tokens - now they need to think/react on that data individually
# basically provide the ability to decide how much to fire based on the communicated and collected information

In [67]:
# next, we start to intersperse tho communication with the computation

In [68]:
class Block(nn.Module):
    """ a transformer block : communication then computation """
    
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embed)
    
    def forward(self, x):
        x = self.sa(x)   # communication
        x = self.ffwd(x) # computation
        return x

In [69]:
# After the addition of these blocks, the network is starting to get pretty deep
# but we are not yet getting good result
# so we will visit one more idea from the attention paper
# residual connections also called skip connections

In [70]:
# the basic idea is that
# there's the residual pathway (without computation)
# you're free to fork off the pathway and perform some more computation and project back to the pathway (via addition)
# remember addition distributes gradients equally to the branches during backprop
# gradient 'superhighway' all the way from supervision all the way to the input uninterrputed
# residual blocks contribute very little in the beginning given how they're init'd
# during learning they start to contribute more

In [71]:
class Block(nn.Module):
    """ a transformer block : communication then computation """
    
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embed)
    
    def forward(self, x):
        x = x + self.sa(x)   # RESIDUAL ADD
        x = x + self.ffwd(x) # RESIDUAL ADD
        return x

In [72]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
    
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out) # outcome of the linear layer to project back into the residual pathway
        return out

In [73]:
class FeedForward(nn.Module):
    " simple linear layer followed by non linearity "
    
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed), # as mentioned in the paper
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed) # projection layer : the final projection back into the residual pathway
        )
        
    def forward(self, x):
        return self.net(x)