## How Does Self-Attention Work?

In [14]:
# Fix all tensor dimensions once up front so the later cells can reuse them.
B, T, D, H = 16, 100, 512, 8  # batch size, tokens per sequence, model width, attention heads
x = torch.rand(B, T, D)  # uniform-random stand-in input of shape (B, T, D)

In [15]:
# One learned D -> D linear map each for keys, queries and values (created in that order).
keys, query, value = (nn.Linear(D, D) for _ in range(3))

In [29]:
# Project the input, then split the D channels into H heads of width D // H and
# move the head axis in front of the token axis: (B, T, D) -> (B, H, T, D // H).
k, q, v = (
    layer(x).view(B, T, H, D // H).transpose(1, 2)
    for layer in (keys, query, value)
)

In [30]:
k.size()  # expected (B, H, T, D // H)

torch.Size([16, 8, 100, 64])

In [31]:
# The batched dot product q·kᵀ can be written with torch.matmul or with the `@`
# operator; on these 4-D tensors, axes (2, 3) and (-2, -1) name the same two dims.
via_matmul = torch.matmul(q, k.transpose(2, 3))
via_operator = q @ k.transpose(-2, -1)
print(via_matmul.shape)
print(via_operator.shape)

torch.Size([16, 8, 100, 100])
torch.Size([16, 8, 100, 100])


In [32]:
# Scaled dot-product scores, shape (B, H, T, T). Dividing by sqrt(head_dim)
# (k.size(-1) == D // H) keeps the logits' variance from growing with the head
# width and saturating the softmax. This matches the scaling in CausalSelfAttention.forward.
scores = torch.matmul(q, k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
# Apply a mask here if needed (e.g. masked_fill with -inf on disallowed positions),
# before the softmax.
scores.shape

torch.Size([16, 8, 100, 100])

In [33]:
# Normalize over the key axis so each query's attention weights sum to 1.
weights = F.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)

In [4]:
import math
import logging

import torch
import torch.nn as nn
from torch.nn import functional as F


class CausalSelfAttention(nn.Module):
    """
    Multi-head self-attention with a causal (lower-triangular) mask and an output projection.

    Each position may attend only to itself and to earlier positions. The layer is
    written out by hand instead of using torch.nn.MultiheadAttention, so every step is visible.

    ``config`` must provide: ``n_embd`` (model width, divisible by ``n_head``),
    ``n_head``, ``attn_pdrop``, ``resid_pdrop`` and ``block_size`` (the largest
    sequence length the causal mask covers).
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        assert width % config.n_head == 0
        # learned projections that produce keys, queries and values for all heads at once
        self.key = nn.Linear(width, width)
        self.query = nn.Linear(width, width)
        self.value = nn.Linear(width, width)
        # dropout on the attention weights and on the projected output
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # linear map applied after the per-head outputs are concatenated
        self.proj = nn.Linear(width, width)
        # ones on and below the diagonal, zeros above it; stored with shape
        # (1, 1, block_size, block_size) so it broadcasts over batch and heads
        size = config.block_size
        lower = torch.tril(torch.ones(size, size))
        self.register_buffer("mask", lower.view(1, 1, size, size))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        # NOTE: layer_past is accepted for interface compatibility but is not used here.
        B, T, C = x.size()
        n_head = self.n_head
        head_size = C // n_head

        def to_heads(t):
            # (B, T, C) -> (B, n_head, T, head_size)
            return t.view(B, T, n_head, head_size).transpose(1, 2)

        k = to_heads(self.key(x))
        q = to_heads(self.query(x))
        v = to_heads(self.value(x))

        # similarity of every query with every key, scaled by 1/sqrt(head_size): (B, nh, T, T)
        scale = 1.0 / math.sqrt(k.size(-1))
        att = (q @ k.transpose(-2, -1)) * scale
        # future positions (mask == 0) get -inf, i.e. zero weight after the softmax
        future = self.mask[:, :, :T, :T] == 0
        att = att.masked_fill(future, float('-inf'))
        att = self.attn_drop(F.softmax(att, dim=-1))

        # weighted sum of values, then put the heads back side by side: (B, nh, T, hs) -> (B, T, C)
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.proj(y))
