In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
# Show tensors as plain decimals with 4 digits of precision (no scientific
# notation) so printed attention outputs are easy to compare by eye.
torch.set_printoptions(sci_mode=False, precision=4)

%matplotlib inline

In [5]:
def karpathy(c_attn, x, n_head, n_embd, flash=True):
    """Causal multi-head self-attention over ``x`` using a fused q/k/v projection.

    Args:
        c_attn: callable applied to ``x`` whose output is split along the last
            dimension into three chunks of width ``n_embd`` (query, key, value).
        x: input tensor of shape (B, T, C).
        n_head: number of attention heads; each head gets ``C // n_head`` channels.
        n_embd: width of each of the q / k / v chunks produced by ``c_attn``.
        flash: when True, use ``F.scaled_dot_product_attention`` with
            ``is_causal=True``; when False, compute masked softmax attention
            explicitly.

    Returns:
        Tensor of shape (B, T, C) with the per-head outputs concatenated along
        the last dimension.
    """
    B, T, C = x.size()  # batch size, sequence length, embedding dimensionality
    head_size = C // n_head

    # Project once, then split into q, k, v and move the head axis forward so
    # that every head is processed as part of the batch: (B, nh, T, hs).
    q, k, v = c_attn(x).split(n_embd, dim=2)
    k = k.view(B, T, n_head, head_size).transpose(1, 2)
    q = q.view(B, T, n_head, head_size).transpose(1, 2)
    v = v.view(B, T, n_head, head_size).transpose(1, 2)

    if flash:
        # Fused attention; PyTorch picks the backend kernel (may be FlashAttention
        # on supported hardware). The causal mask is applied via is_causal=True.
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=True)
    else:
        # Manual attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Lower-triangular causal mask, built on the same device as the scores
        # so masked_fill does not hit a CPU/CUDA device mismatch.
        bias = torch.tril(torch.ones(T, T, device=att.device)).view(1, 1, T, T)
        att = att.masked_fill(bias == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

    # Re-assemble all head outputs side by side: (B, nh, T, hs) -> (B, T, C)
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    return y

In [10]:
# Small configuration so the printed output stays readable.
n_embd = 8
n_head = 2
flash = True
# Use the GPU for the fused-attention path when one is present; otherwise fall
# back to CPU so the cell still runs on a fresh kernel without CUDA.
device = 'cuda' if flash and torch.cuda.is_available() else 'cpu'
# Seed before creating the layer and the input so both are reproducible.
torch.manual_seed(13)
c_attn = nn.Linear(n_embd, 3 * n_embd, bias=False).to(device)
x = torch.randn(2, 4, n_embd).to(device)
karpathy(c_attn, x, n_head, n_embd, flash)

(tensor([[[    -0.1038,     -0.1869,     -0.9511,     -0.4516,      0.7749,
               -0.8839,     -0.5570,     -0.7959],
          [    -0.1598,     -0.5558,     -0.0879,      0.0378,     -0.9518,
                0.7826,     -0.0118,     -0.2513],
          [    -0.1021,     -0.4652,     -0.1973,      0.0524,     -0.0011,
               -0.0733,     -0.2401,     -0.3144],
          [     0.1038,     -0.3479,     -0.1774,     -0.1089,     -0.4437,
                0.1967,     -0.0915,     -0.2301]],
 
         [[    -0.5489,     -0.8790,      1.0406,      0.5399,      1.1409,
               -0.1098,      0.6044,      1.5876],
          [    -0.5449,     -0.5713,      0.9675,      0.5635,     -0.0129,
                0.4037,      0.4179,      0.9543],
          [    -0.2615,     -0.1719,      0.4718,      0.2990,     -0.2377,
                0.4195,      0.1971,      0.4427],
          [    -0.4090,      0.0866,      0.3747,      0.1671,     -0.4818,
               -0.0093,     -0.0