In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [12]:
# Our silly language
# W 0 W W W W W
# i.e. a random "word" W, a 0 separator, then W repeated until the sequence is full

def generate(L=64, max_start=10, vocab_size=4):
    """Sample one sequence of the toy "W 0 W W W ..." language.

    A random prefix ``W`` with ``2 <= len(W) < max_start`` is drawn from the
    tokens ``1 .. vocab_size - 1`` (0 is reserved as the separator). The result
    is ``W + [0] + W + W + ...`` truncated to exactly ``L`` tokens.

    Args:
        L: length of the returned sequence.
        max_start: exclusive upper bound on the prefix length; must be > 2
            (numpy raises ValueError otherwise).
        vocab_size: total number of token ids including the separator 0;
            must be >= 2.

    Returns:
        A list of ``L`` plain Python ints.
    """
    start_len = np.random.randint(2, max_start)
    # .tolist() yields Python ints instead of np.int64 scalars (clean repr, easy to serialize)
    start = np.random.randint(1, vocab_size, size=start_len).tolist()
    out = start + [0]
    while len(out) < L:
        out += start
    return out[:L]

generate()

[2,
 1,
 3,
 2,
 3,
 0,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 3,
 2,
 3,
 2,
 1,
 3]

In [15]:
DEBUG = True

class SelfAttentionWithoutMask(nn.Module):
    """Multi-head self-attention with no mask: every position attends to all positions.

    Args:
        embed_size: width E of the input/output embeddings.
        num_heads: number of heads; each head works on ``embed_size // num_heads`` channels.

    Raises:
        ValueError: if ``embed_size`` is not divisible by ``num_heads``.

    When the global ``DEBUG`` flag is set, ``forward`` also recomputes attention
    explicitly (einsum and matmul) and prints shapes and the discrepancy against
    ``F.scaled_dot_product_attention``.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()

        # The heads split the embedding into equal slices, so the division must be exact;
        # otherwise reshape() in forward() fails later with an opaque shape error.
        if embed_size % num_heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})"
            )

        self.q_proj = nn.Linear(embed_size, embed_size, bias=False)
        self.k_proj = nn.Linear(embed_size, embed_size, bias=False)
        self.v_proj = nn.Linear(embed_size, embed_size, bias=False)
        self.o_proj = nn.Linear(embed_size, embed_size, bias=False)

        self.num_heads = num_heads
        self.head_size = embed_size // num_heads

    def forward(self, x):
        """Map ``x`` of shape (B, L, E) to an output of the same shape."""
        B, L, E = x.shape
        # (B, L, E) -> (B, L, heads, head_size) -> (B, heads, L, head_size)
        q = self.q_proj(x).reshape(B, L, self.num_heads, self.head_size).transpose(1, 2)
        k = self.k_proj(x).reshape(B, L, self.num_heads, self.head_size).transpose(1, 2)
        v = self.v_proj(x).reshape(B, L, self.num_heads, self.head_size).transpose(1, 2)

        if DEBUG:
            print("input shape", x.shape)
            print("q shape", q.shape)

        # This is what you want to use in practice
        attn_out = F.scaled_dot_product_attention(q, k, v)
        if DEBUG:
            print("attn out shape", attn_out.shape)

            # Here we explicitly calculate attention

            # First we calculate attention logits scaled by 1/sqrt(head_size);
            # these two formulas are equivalent
            attn_logits = torch.einsum("bhqe, bhke -> bhqk", q, k) / np.sqrt(self.head_size)
            attn_logits2 = q.matmul(k.transpose(-2,-1)) / np.sqrt(self.head_size)
            print("attn logits shapes", attn_logits.shape, attn_logits2.shape)
            print("max diff", (attn_logits - attn_logits2).abs().amax().item())

            # Each query row becomes a probability distribution over the keys
            attn_matrix = torch.softmax(attn_logits, dim=-1)

            attn_out2 = torch.einsum("bhqk, bhke -> bhqe", attn_matrix, v)
            print("attn out shapes", attn_out2.shape, attn_out.shape)

            # Relative squared error between the explicit and the fused implementation
            print("attn diff", (attn_out - attn_out2).square().sum() / attn_out.square().sum())

        # Merge the heads back: (B, heads, L, head_size) -> (B, L, E)
        reshaped = attn_out.transpose(1,2).reshape(B, L, E)
        return self.o_proj(reshaped)


sa = SelfAttentionWithoutMask(32, 4)

sa(torch.rand(10, 20, 32)).shape

input shape torch.Size([10, 20, 32])
q shape torch.Size([10, 4, 20, 8])
attn out shape torch.Size([10, 4, 20, 8])
attn logits shapes torch.Size([10, 4, 20, 20]) torch.Size([10, 4, 20, 20])
max diff 0.0
attn out shapes torch.Size([10, 4, 20, 8]) torch.Size([10, 4, 20, 8])
attn diff tensor(1.5698e-14, grad_fn=<DivBackward0>)


torch.Size([10, 20, 32])

In [4]:
# We will use causal attention, i.e. each sequence step sees only itself and the steps before it

DEBUG = True

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention: position t attends only to positions <= t.

    Args:
        embed_size: width E of the input/output embeddings.
        num_heads: number of heads; each head works on ``embed_size // num_heads`` channels.

    Raises:
        ValueError: if ``embed_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()

        # The heads split the embedding into equal slices, so the division must be exact;
        # otherwise reshape() in forward() fails later with an opaque shape error.
        if embed_size % num_heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})"
            )

        self.q_proj = nn.Linear(embed_size, embed_size, bias=False)
        self.k_proj = nn.Linear(embed_size, embed_size, bias=False)
        self.v_proj = nn.Linear(embed_size, embed_size, bias=False)
        self.o_proj = nn.Linear(embed_size, embed_size, bias=False)

        self.num_heads = num_heads
        self.head_size = embed_size // num_heads

    def forward(self, x):
        """Map ``x`` of shape (B, L, E) to an output of the same shape."""
        B, L, E = x.shape
        # (B, L, E) -> (B, L, heads, head_size) -> (B, heads, L, head_size)
        q = self.q_proj(x).reshape(B, L, self.num_heads, self.head_size).transpose(1, 2)
        k = self.k_proj(x).reshape(B, L, self.num_heads, self.head_size).transpose(1, 2)
        v = self.v_proj(x).reshape(B, L, self.num_heads, self.head_size).transpose(1, 2)

        # This is what you want to use in practice; is_causal=True applies the
        # lower-triangular mask so no position can look at future positions
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge the heads back: (B, heads, L, head_size) -> (B, L, E)
        reshaped = attn_out.transpose(1,2).reshape(B, L, E)
        return self.o_proj(reshaped)


sa = CausalSelfAttention(32, 4)

sa(torch.rand(10, 20, 32)).shape

torch.Size([10, 20, 32])

In [5]:
# Let's take some input and change its 4th element (index 3)
inp = torch.rand(1, 10, 32)
# Perturb a copy so that `inp` keeps holding the original input
inp_changed = inp.clone()
inp_changed[:, 3] += 0.1

# We only need forward values here, so skip building the autograd graph
with torch.no_grad():
    out1 = sa(inp)
    out2 = sa(inp_changed)

# Outputs change only from the 4th position onward; the first 3 positions are not affected
(out1 - out2).square().sum(dim=2)

tensor([[0.0000, 0.0000, 0.0000, 0.0025, 0.0017, 0.0012, 0.0009, 0.0007, 0.0005,
         0.0004]], grad_fn=<SumBackward1>)

In [20]:
# Let's build the rest of the model

class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention then an MLP, each with a residual.

    Args:
        embed_size: width E of the token embeddings.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()

        # One LayerNorm per sub-layer: norm1 before attention, norm2 before the MLP
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.self_attn = CausalSelfAttention(embed_size, num_heads)

        # We also have an MLP block. Its contents differ between models (LLaMA uses a more
        # complicated structure); we will keep the simple one
        self.mlp = nn.Sequential(
            nn.Linear(embed_size, 4*embed_size),
            nn.SiLU(),
            nn.Linear(4*embed_size, embed_size)
        )

    def forward(self, x):
        """Map ``x`` of shape (B, L, E) to an output of the same shape."""
        # Why do we do self addition and normalization? Next lecture!
        x = x + self.self_attn(self.norm1(x))
        # The MLP gets its own norm; reusing norm1 here would leave norm2 unused and untrained
        x = x + self.mlp(self.norm2(x))
        return x
        
class Model(nn.Module):
    """Decoder-only transformer language model.

    Args:
        n_blocks: number of transformer blocks.
        embed_size: width E of the token embeddings.
        num_heads: attention heads per block.
        vocab_size: number of token ids.
        max_poses: maximum supported sequence length (rows of the positional table).
        use_pos_embed: add the learned absolute positional embeddings to the token
            embeddings. Toggle this instead of commenting the line out.
    """

    def __init__(self, n_blocks, embed_size, num_heads, vocab_size, max_poses, use_pos_embed=True):
        super().__init__()

        self.blocks = nn.Sequential(
            *[Block(embed_size, num_heads) for _ in range(n_blocks)]
        )

        self.embed = nn.Embedding(vocab_size, embed_size)
        self.pos_embed = nn.Parameter(torch.randn(max_poses, embed_size))
        self.use_pos_embed = use_pos_embed
        self.out_norm = nn.LayerNorm(embed_size)
        self.out = nn.Linear(embed_size, vocab_size)

    def forward(self, sequences):
        """Map token ids of shape (B, L) to next-token logits of shape (B, L, vocab_size)."""
        embeded = self.embed(sequences)

        if self.use_pos_embed:
            seq_len = embeded.shape[1]
            if seq_len > self.pos_embed.shape[0]:
                raise ValueError(
                    f"sequence length {seq_len} exceeds max_poses {self.pos_embed.shape[0]}"
                )
            # The (L, E) table broadcasts over the batch dimension
            embeded = embeded + self.pos_embed[:seq_len]

        if DEBUG:
            print("embeded shape", embeded.shape)

        embeded = self.blocks(embeded)

        embeded = self.out_norm(embeded)
        return self.out(embeded)

model = Model(2, 32, 2, 4, 64)

# For each sequence position we get the logits (inputs for softmax) for the next token
model(torch.randint(4, size=(2,16))).shape

torch.Size([2, 16, 4])

In [21]:
DEBUG = False

# How to sample (suboptimal implementation: the whole prefix is re-encoded at every
# step instead of caching keys/values)

model_device = next(model.parameters()).device
seq = [3]
# Inference only, so no autograd graph is needed
with torch.no_grad():
    for _ in range(10):
        # Logits for the token that follows the last position
        out = model(torch.tensor([seq], device=model_device))[0, -1]
        # Convert to probs
        out_p = torch.softmax(out, dim=-1)
        # Sample next token
        next_token = torch.multinomial(out_p, num_samples=1).item()
        seq.append(next_token)
        print(seq)

[3, 1]
[3, 1, 0]
[3, 1, 0, 3]
[3, 1, 0, 3, 1]
[3, 1, 0, 3, 1, 0]
[3, 1, 0, 3, 1, 0, 3]
[3, 1, 0, 3, 1, 0, 3, 3]
[3, 1, 0, 3, 1, 0, 3, 3, 3]
[3, 1, 0, 3, 1, 0, 3, 3, 3, 3]
[3, 1, 0, 3, 1, 0, 3, 3, 3, 3, 1]


In [22]:
# Let's train
# Notice that we train all positions at once: thanks to the causal mask, the output
# at step 3 is influenced only by the first 3 steps and not by step 4

# Fall back to CPU so the notebook also runs on machines without a GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
opt = torch.optim.AdamW(model.parameters(), 1e-4, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

for step in range(10000):
    data = torch.tensor([generate() for _ in range(16)], dtype=torch.long, device=device)
    out = model(data)
    # Notice the shift: the logits at position t predict token t + 1.
    # CrossEntropyLoss expects (N, C) logits, so batch and sequence dims are merged.
    loss = loss_fn(out[:,:-1].flatten(0, -2), data[:,1:].flatten())

    if step % 100 == 0:
        # Predictions are only needed for logging, so compute them only here
        out_pred = out.argmax(dim=-1)
        print(step, loss.item(), "acc", (out_pred[:,:-1] == data[:,1:]).sum().item() / data[:,1:].numel(),
              "acc after step 10", (out_pred[:,9:-1] == data[:,10:]).sum().item() / data[:,10:].numel())

    loss.backward()
    opt.step()
    opt.zero_grad()
    

0 1.3154503107070923 acc 0.4087301587301587 acc after step 10 0.41782407407407407
100 0.926128625869751 acc 0.6805555555555556 acc after step 10 0.7118055555555556
200 0.9247331023216248 acc 0.6190476190476191 acc after step 10 0.6493055555555556
300 0.9681235551834106 acc 0.5 acc after step 10 0.5185185185185185
400 0.9258744120597839 acc 0.5932539682539683 acc after step 10 0.6203703703703703
500 0.8625814318656921 acc 0.6656746031746031 acc after step 10 0.6921296296296297
600 0.9161351919174194 acc 0.5486111111111112 acc after step 10 0.5706018518518519
700 0.8958262801170349 acc 0.6111111111111112 acc after step 10 0.6400462962962963
800 0.9110379815101624 acc 0.5565476190476191 acc after step 10 0.5729166666666666
900 0.9073696732521057 acc 0.6448412698412699 acc after step 10 0.6782407407407407
1000 0.9195709228515625 acc 0.5873015873015873 acc after step 10 0.6099537037037037
1100 0.8853381276130676 acc 0.6130952380952381 acc after step 10 0.6412037037037037
1200 0.911009788513

9900 0.4925936460494995 acc 0.7708333333333334 acc after step 10 0.8148148148148148


In [23]:
# Let's sample again, this time starting from a prompt in our language
seq = [3, 1, 2, 3, 0]

# Put the inputs wherever the model lives (GPU after training, if one was available)
model_device = next(model.parameters()).device
# Inference only, so no autograd graph is needed
with torch.no_grad():
    for _ in range(10):
        # Logits for the token that follows the last position
        out = model(torch.tensor([seq], device=model_device))[0, -1]
        # Convert to probs
        out_p = torch.softmax(out, dim=-1)
        # Sample next token
        next_token = torch.multinomial(out_p, num_samples=1).item()
        seq.append(next_token)
        print(seq)

[3, 1, 2, 3, 0, 3]
[3, 1, 2, 3, 0, 3, 2]
[3, 1, 2, 3, 0, 3, 2, 3]
[3, 1, 2, 3, 0, 3, 2, 3, 1]
[3, 1, 2, 3, 0, 3, 2, 3, 1, 2]
[3, 1, 2, 3, 0, 3, 2, 3, 1, 2, 3]
[3, 1, 2, 3, 0, 3, 2, 3, 1, 2, 3, 3]
[3, 1, 2, 3, 0, 3, 2, 3, 1, 2, 3, 3, 1]
[3, 1, 2, 3, 0, 3, 2, 3, 1, 2, 3, 3, 1, 2]
[3, 1, 2, 3, 0, 3, 2, 3, 1, 2, 3, 3, 1, 2, 3]
