# Next-token prediction

Not sure if this is how data is being prepared in actual training, but as a toy example, we can chop out different parts of a sentence, and then ask it to predict the next token.

In [1]:
# load inputs
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

text = text[:3000]

# get all the unique characters in the input
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("Unique characters in the inputs:" + ''.join(chars))
print(f"voab size {vocab_size}")


Unique characters in the inputs:
 !',-.:;?ABCEFHILMNORSTUVWYabcdefghijklmnoprstuvwyz
voab size 52


In [2]:
# create a mapping from characters to integers
char_to_int = {char : i for i, char in enumerate(chars)}
int_to_char = {i: char for i, char in enumerate(chars)}

def encode(input_string):
    return [char_to_int[char] for char in input_string]

def decode(input_list):
    decoded_chars = [int_to_char[idx] for idx in input_list]
    return "".join(decoded_chars)

print(encode("hii there"))
print(decode(encode("hii there")))

[35, 36, 36, 1, 46, 35, 32, 44, 32]
hii there


In [3]:
# tokenise input
import torch
data = torch.tensor(encode(text), dtype = torch.long)

n = int(len(data) * 0.9)
train_data = data[:n]
val_data = data[n: ]

# set up context length
block_size = 8

x = train_data[:block_size]
y = train_data[1:block_size+1]

print(f"input sequence {x}")

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"for input {context}, target is {target}")

input sequence tensor([14, 36, 44, 45, 46,  1, 12, 36])
for input tensor([14]), target is 36
for input tensor([14, 36]), target is 44
for input tensor([14, 36, 44]), target is 45
for input tensor([14, 36, 44, 45]), target is 46
for input tensor([14, 36, 44, 45, 46]), target is 1
for input tensor([14, 36, 44, 45, 46,  1]), target is 12
for input tensor([14, 36, 44, 45, 46,  1, 12]), target is 36
for input tensor([14, 36, 44, 45, 46,  1, 12, 36]), target is 46


# Self-attention

## single-head attention

In [4]:
import torch.nn as nn
import torch.nn.functional as F

In [5]:
B,T,C = 4,6,12 # batch, time, channels
x = torch.randn(B,T,C)

In [6]:
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
wei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
print(wei)

v = value(x)
out = wei @ v

tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.6346e-01, 7.3654e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [3.9511e-01, 2.9510e-01, 3.0979e-01, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [3.4207e-01, 2.9953e-02, 8.3407e-02, 5.4457e-01, 0.0000e+00,
          0.0000e+00],
         [1.8838e-01, 7.1301e-02, 2.1903e-02, 6.2837e-01, 9.0044e-02,
          0.0000e+00],
         [1.3378e-01, 2.0116e-01, 1.4047e-01, 1.2868e-01, 1.3532e-01,
          2.6059e-01]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [8.3133e-01, 1.6867e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.5852e-01, 4.2375e-01, 3.1773e-01, 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [4.1047e-01, 3.1722e-01, 2.4141e-01, 3.0894e-02, 0.0000e+00,
          0.0000e+00],
         [2.2129e-01, 7.2234e-01, 4.5479e-02, 7.5519e-03, 3.3372e-03

## fast vs slow implementation

```
# --------------------------------------------------------------------------
# Layout of the “big” QKV weight matrix expected by each implementation
#
#   ┌───────────── slow path (“hand-rolled” heads) ────────────────┐
#   │      Head 1     |      Head 2     |     …    |     Head H    │
#   │    Q1  K1  V1   |    Q2  K2  V2   |     …    |   QH  KH  VH  │
#   └──────────────────────────────────────────────────────────────┘

#
#   ┌──────── fast path (CausalSelfAttention.c_attn) ─────────┐
#   │       ALL Q       │        ALL K       │       ALL V    │
#   │    Q1 Q2 …  QH    |    K1 K2  …  KH    |  V1 V2  … VH   │     
#   └─────────────────────────────────────────────────────────┘
# --------------------------------------------------------------------------
```

In [None]:
import math
class SingleHeadCausalAttention(nn.Module):
    def __init__(self, config, keep_bias_term = True):
        super().__init__()
        
        self.dim = int(config.n_embd / config.n_head)
        
        self.q_proj = nn.Linear(config.n_embd, self.dim, bias = keep_bias_term)
        self.v_proj = nn.Linear(config.n_embd, self.dim, bias = keep_bias_term)
        self.k_proj = nn.Linear(config.n_embd, self.dim, bias = keep_bias_term)
        
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)))
    def forward(self, x):
        
        B, T, C = x.size()
        
        q = self.q_proj(x) # [B, T, hs]
        k = self.k_proj(x) # [B, T, hs]
        v = self.v_proj(x) # [B, T, hs]
        
        att =  q @ k.transpose(1, 2) * (1.0 / math.sqrt(self.dim))
            # [B, T, T] i-th row, the weights for token i
        att = att.masked_fill(self.bias[:T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim = -1) # softmax for each row
        
        z = att @ v # [B, T, hs]
        
        return z

class MultiHeadCausalAttention(nn.Module):
    
    def __init__(self, config, keep_bias_term = True):
        super().__init__()
        
        self.multi_head_attention = nn.ModuleList([SingleHeadCausalAttention(config,keep_bias_term) for _ in range(config.n_head)])
        
        self.multi_head_projection = nn.Linear(config.n_embd, config.n_embd, bias = keep_bias_term)
        
    def forward(self, x):
        
        concatenaed_multi_head = torch.cat([h(x) for h in self.multi_head_attention], dim = -1)
        projected_multi_head = self.multi_head_projection(concatenaed_multi_head)
        
        return projected_multi_head

In [None]:
from model import CausalSelfAttention, GPTConfig
import torch

slow_multi_head_attention = MultiHeadCausalAttention(GPTConfig, keep_bias_term = False)
fast_multi_head_attention = CausalSelfAttention(GPTConfig, keep_bias_term = False)


all_W_q = torch.cat([single_head.q_proj.weight.data for single_head in slow_multi_head_attention.multi_head_attention], dim=0)
all_W_k = torch.cat([single_head.k_proj.weight.data for single_head in slow_multi_head_attention.multi_head_attention], dim=0)
all_W_v = torch.cat([single_head.v_proj.weight.data for single_head in slow_multi_head_attention.multi_head_attention], dim=0)

concatenate_multihead_qkv = torch.cat([all_W_q, all_W_k, all_W_v], dim = 0)

fast_multi_head_attention.c_attn.weight.data = concatenate_multihead_qkv
fast_multi_head_attention.c_proj.weight.data = slow_multi_head_attention.multi_head_projection.weight.data


B, T, C = 2, 5, GPTConfig.n_embd
torch.manual_seed(42)
test_input = torch.rand((B, T, C))

slow_output = slow_multi_head_attention(test_input)
fast_output = fast_multi_head_attention(test_input)

torch.testing.assert_close(slow_output, fast_output)