In [1]:
import torch
from torch.nn import functional as F
from torch import nn
import math

In [38]:
# Notebook-wide hyperparameters; later cells read these as module-level globals.
d_model = 12  # input feature size of the Head projections; last dim of the demo tensor X
block_size = 6  # side length of the causal (tril) mask; time dimension of X
batch_size = 4  # leading (batch) dimension of X
dropout = 0.1  # in the visible cells only the commented-out dropout lines in Head refer to this

In [3]:
class BigramLanguageModel(nn.Module):
    """Bigram language model backed by a single embedding table.

    The table has shape (vocab_size, vocab_size): looking up token ``i``
    returns a row of unnormalised scores over the next token.
    """

    def __init__(self, vocab_size):
        super().__init__()

        self.embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, ix, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            ix: integer token ids of shape (B, T).
            targets: optional integer token ids of shape (B, T).

        Returns:
            A ``(logits, loss)`` tuple. Without targets, ``logits`` has shape
            (B, T, C) and ``loss`` is None. With targets, ``logits`` is
            flattened to (B*T, C) and ``loss`` is a scalar tensor.
        """
        logits = self.embedding_table(ix)  # (B, T, C)

        # Identity check: `targets == None` would be dispatched through the
        # tensor's comparison operator instead of testing for "not provided".
        if targets is None:
            return logits, None

        # cross_entropy wants (N, C) scores against (N,) class indices.
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(input=logits, target=targets.view(B * T))
        return logits, loss

    def generate(self, n_chars, ix):
        """Sample ``n_chars`` extra tokens and append them to ``ix``.

        Args:
            n_chars: number of tokens to sample.
            ix: integer token ids of shape (B, T) used as the starting context.

        Returns:
            Tensor of shape (B, T + n_chars): the context followed by the samples.
        """
        for _ in range(n_chars):
            logits, _loss = self(ix)  # (B, T, C); loss is None without targets
            last_logits = logits[:, -1, :]  # (B, C): only the final position predicts the next token
            probs = F.softmax(last_logits, dim=-1)  # (B, C)
            next_ix = torch.multinomial(input=probs, num_samples=1)  # (B, 1)
            ix = torch.cat((ix, next_ix), dim=1)  # (B, T + 1)

        return ix

In [39]:
class Head(nn.Module):
    """A single head of causal (masked) self-attention."""

    def __init__(self, head_size=d_model):
        super().__init__()
        # Bias-free projections from the model width (global d_model) to the head width.
        self.key = nn.Linear(d_model, head_size, bias=False)
        self.query = nn.Linear(d_model, head_size, bias=False)
        self.value = nn.Linear(d_model, head_size, bias=False)
        # Lower-triangular ones; stored as a buffer so it is not a trainable parameter.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, time-step, channels).

        Returns a tensor of shape (batch, time-step, head size).
        """
        _, seq_len, _ = x.shape
        queries = self.query(x)  # (B, T, hs)
        keys = self.key(x)       # (B, T, hs)
        values = self.value(x)   # (B, T, hs)

        # Scaled affinities: (B, T, hs) @ (B, hs, T) -> (B, T, T)
        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale

        # Entries where the tril buffer is 0 lie above the diagonal (future positions).
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = F.softmax(scores, dim=-1)  # (B, T, T)

        # Weighted sum of the values: (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return weights @ values

In [5]:
# Random demo batch of shape (batch_size, block_size, d_model).
X = torch.randn(batch_size, block_size, d_model)
# Lower-triangular ones: row t has ones at columns <= t (the positions t may attend to).
tril = torch.tril(torch.ones((block_size, block_size)))

# Self-attention: queries, keys and values all start from the same tensor.
q = X
k = X
v = X

m = Head(d_model)
# Head.forward returns a single tensor. Unpacking it into four names would
# iterate over dim 0 and, with batch_size == 4, silently bind y to the first
# batch element instead of the full output.
y = m(X)

In [6]:
# One head attention, recomputed by hand to cross-check Head.forward

B,T,C = X.shape

query_weights = m.query.weight
key_weights = m.key.weight
value_weights = m.value.weight

q_layer = nn.Linear(d_model, d_model, bias=False)
k_layer = nn.Linear(d_model, d_model, bias=False)
v_layer = nn.Linear(d_model, d_model, bias=False)

# Reuse the head's own parameters so both computations share identical weights.
q_layer.weight = query_weights
k_layer.weight = key_weights
v_layer.weight = value_weights

Q = q_layer(q) # (B, T, head_size)
K = k_layer(k) # (B, T, head_size)
V = v_layer(v) # (B, T, head_size)

# Scale by the key width, exactly as Head.forward does, rather than by the
# global d_model (which other cells rebind).
x = (Q @ torch.transpose(K, 1, 2)) * (K.shape[-1] ** -0.5) # (B, T, hs) @ (B, hs, T) --> (B, T, T)
x = x.masked_fill(tril[:T, :T] == 0, float('-inf')) # (B, T, T)
x = F.softmax(x, -1)
x = x @ V # (B, T, T) @ (B, T, hs) --> (B, T, hs)

# Results should match; compare with a tolerance instead of exact float equality
torch.isclose(y, x).all()

tensor(True)

In [28]:
# Multi-head attention

n_heads = 2
d_model = 10
x = torch.randn(4, 6, d_model)

# Self-attention: queries, keys and values all start from the same tensor.
q = x
k = x
v = x

B,T,C = x.shape
dk = d_model // n_heads
dv = d_model // n_heads
# Build the causal mask here so the cell does not rely on `tril` from another cell.
mask = torch.tril(torch.ones(T, T)).unsqueeze(0).unsqueeze(0) # 1,1,T,T

q_layer = nn.Linear(d_model, d_model, bias=False)
k_layer = nn.Linear(d_model, d_model, bias=False)
v_layer = nn.Linear(d_model, d_model, bias=False)
att_proj = nn.Linear(d_model, d_model, bias=False)

q = q_layer(q) # B,T,C
k = k_layer(k) # B,T,C
v = v_layer(v) # B,T,C

# Split the channel axis into heads and move the head axis next to the batch axis.
q = q.view(B,T,n_heads,dk).permute(0,2,1,3) # B,n_h,T,dk
k = k.view(B,T,n_heads,dk).permute(0,2,1,3) # B,n_h,T,dk
v = v.view(B,T,n_heads,dv).permute(0,2,1,3) # B,n_h,T,dv

# Use the per-head q/k/v (not the single-head Q/K/V left over from the
# previous cell) and scale the scores by 1/sqrt(dk).
att = (q @ k.transpose(-2,-1)) * dk ** -0.5 # B,n_h,T,dk @ B,n_h,dk,T --> B,n_h,T,T
att = att.masked_fill(mask == 0, float('-inf')) # B,n_h,T,T
att = F.softmax(att, dim=-1) # B,n_h,T,T
x = att @ v # B,n_h,T,T @ B,n_h,T,dv --> B,n_h,T,dv
# Undo the permute before flattening; a plain view would interleave the head and time axes.
x = x.transpose(1, 2).contiguous().view(B, T, C) # B,T,C

out = att_proj(x)

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3871, 0.6129, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3498, 0.4887, 0.1614, 0.0000, 0.0000, 0.0000],
         [0.2544, 0.1407, 0.4314, 0.1735, 0.0000, 0.0000],
         [0.1653, 0.4699, 0.1075, 0.1810, 0.0762, 0.0000],
         [0.1356, 0.2197, 0.1423, 0.1428, 0.1635, 0.1961]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5592, 0.4408, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2830, 0.4159, 0.3011, 0.0000, 0.0000, 0.0000],
         [0.2446, 0.2622, 0.4413, 0.0518, 0.0000, 0.0000],
         [0.1133, 0.0324, 0.4285, 0.3145, 0.1114, 0.0000],
         [0.1603, 0.1540, 0.2100, 0.1321, 0.1821, 0.1615]]],
       grad_fn=<SelectBackward0>)


In [33]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Args:
        n_heads: number of attention heads; must divide ``d_model``.
        d_model: channel width of the input and of the output.
        block_size: longest sequence length the causal mask covers.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, n_heads, d_model, block_size, dropout=0.1):

        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model // n_heads
        self.dropout = nn.Dropout(dropout)
        self.query = nn.Linear(d_model, d_model, bias=False)
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.att_proj = nn.Linear(d_model, d_model, bias=False)
        # Boolean mask that is True strictly above the diagonal, i.e. at the
        # future positions that must not be attended to.
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())

    def forward(self, x):
        """Apply masked self-attention.

        Args:
            x: tensor of shape (B, T, d_model) with T <= block_size.

        Returns:
            Tensor of shape (B, T, d_model).

        Raises:
            ValueError: if T exceeds the block_size the mask was built for.
        """
        B, T, C = x.shape
        if T > self.mask.shape[0]:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.mask.shape[0]}"
            )

        # linear projections (q, k and v would come from different inputs for cross-attention)
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # split channels into heads: B,T,C --> B,n_h,T,hd
        # (instance attributes, not notebook globals, so the module is self-contained)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # scaled dot-product attention: B,n_h,T,hd @ B,n_h,hd,T --> B,n_h,T,T
        att = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5
        # mask is True at future positions, so fill exactly those with -inf;
        # slicing to T supports sequences shorter than block_size
        att = att.masked_fill(self.mask[:T, :T], float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v  # B,n_h,T,T @ B,n_h,T,hd --> B,n_h,T,hd
        # move heads back next to the channel axis before flattening to B,T,C
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.att_proj(y)

In [None]:
# Multi-head attention robust to any n_head
# Padding will be added to the input tensor

# calculate the amount of padding required
# add the padding to the input tensor - it must be zeros

# Example sizes where the width is not a multiple of the head count (10 % 3 == 1).
# NOTE: these rebind the notebook-level d_model / n_heads set in earlier cells.
d_model = 10
n_heads = 3

def calc_padding(d_model, n_heads):
    """Return how much padding brings ``d_model`` up to a multiple of ``n_heads``.

    Args:
        d_model: current channel width.
        n_heads: number of attention heads the width must be divisible by.

    Returns:
        The smallest non-negative integer ``pad`` such that
        ``(d_model + pad) % n_heads == 0``; 0 when it already divides evenly.
    """
    # (-d_model) % n_heads is the distance from d_model up to the next multiple
    # of n_heads, and is 0 when d_model is already a multiple.
    return (-d_model) % n_heads

In [None]:
# Zero-pad the last dimension of a 4-D tensor by one element on each side.
t4d = torch.empty(3, 3, 4, 2)
p1d = (1, 1)  # (left, right) padding for the final dimension
out = F.pad(t4d, pad=p1d, mode="constant", value=0)
print(out.shape)

torch.Size([3, 3, 4, 4])


In [None]:
# To Do
# Attention - Masked self attention
# input tensor B,T,C
# q, k, v projections - linear layers with no bias or activation function
# get projections by passing q,k,v through layers
# attention scores - q @ k where the mask is applied
# att_scores @ v
# create multi-head attention