In [2]:
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, einsum

import math
from collections import OrderedDict
torch.manual_seed(1234)

<torch._C.Generator at 0x7af0140f2130>

### Sinusoidal Positional Encodings

This is used to inject some positional information into the embeddings of the token sequence. Since we compute the next token for every single token in parallel, we want some info about the position of a token within a sequence to be represented as well.
1. Unique encoding for each position (across all sequences)
2. Generalize to longer sequence than seen in training
3. Generated deterministically (so the model can learn it)
4. Linear relation between 2 encoded positions (again to help the model learn relationships)


Given a position $pos$ output a vector $d_{model}$ such that for each location $i$ in the vector the output is $$PE_{pos,2i} = \sin(pos/10000^{2i/d_{model}})$$ $$PE_{pos,2i+1} = \cos(pos/10000^{2i/d_{model}})$$
for even and odd indices respectively

we refactor as $$PE_{pos,2i} = \sin(pos.w)$$ $$PE_{pos,2i+1} = \cos(pos.w)$$ where $w=1/(10000^{2i/d_{model}}) for 0 <= 2i <= d_{model}$

![Sinusoidal PE visualization](fleetwood_sinusoidal.png "https://fleetwood.dev/posts/you-could-have-designed-SOTA-positional-encoding")
Look at the functions. For a dimension $i$ the $sin/cos$ function's wave starts out extremely quickly changing values for small $i$, and slows down a lot up to a wavelength of $10000*2\pi$.

#### Derivation to use tensor ops
$$\frac{1}{10000^{k/d_{model}}} = 10000^{-k/d_{model}} = \exp(\log(10000^{-k/d_{model}}))$$
$$ = \exp(-k/d_{model} * \log(10000))$$
this is basically
$$torch.exp(k * (-\frac{1}{d_{model}}) * \log(10000))$$

In [4]:
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000) -> None:
        """Construct the entire positional encoding."""
        super().__init__()
        w = torch.exp(torch.arange(0, d_model, 2) * (-1/d_model) * math.log(10000))
        PE = torch.zeros([max_len, d_model])
        positions = torch.arange(0, max_len)
        # We have 2 vectors
        assert (positions.shape == (max_len,))
        assert (w.shape == (d_model // 2,))
        # If we reshape positions into a row of columns [[0], [1], [2], ...] then 
        # positions @ w -> gives us a matrix of shape [max_len, d_model/2 ]
        positions = rearrange(positions, "(len column) -> len column", column=1)
        PE[:, 0::2] = torch.sin(positions * w)
        PE[:, 1::2] = torch.cos(positions * w)
        self.register_buffer("PE", PE) # this is not to be a learnable parameter
        # However we do want it to be moved along with model.to(device) 
    
    def forward(self):
        return self.PE

In [5]:
# output pe -> [d_model]
def sinusoidal_position_encoding(pos: int, d_model: int = 512) -> torch.Tensor:
    # w = 1/10000**(2i/d_model)
    w = torch.tensor([1/(10000**(k/d_model)) for k in range(0, d_model, 2)])
    x_indices = pos * w 
    PE_even = torch.sin(x_indices)
    PE_odd = torch.cos(x_indices)
    PE_even = rearrange(PE_even, "(new_dim l) -> l new_dim", new_dim=1)
    PE_odd = rearrange(PE_odd, "(new_dim l) -> l new_dim", new_dim=1)
    interleaved_result = rearrange([PE_even, PE_odd], "function d_model element -> (d_model element function)")
    assert (interleaved_result.shape == (d_model,))
    return interleaved_result

In [6]:
def generate_position_encoding(max_len: int = 5000, d_model: int = 512) -> torch.Tensor:
    all_positions_encoding = torch.zeros([max_len, d_model], dtype=torch.float32)
    print(all_positions_encoding.shape)
    for position in range(max_len):
        all_positions_encoding[position, :] += sinusoidal_position_encoding(position, d_model)
    return all_positions_encoding

#### Note below the speed differential of the vectorized pytorch ops. Eye the units 👀

In [7]:
%%timeit
generate_position_encoding()

torch.Size([5000, 512])
torch.Size([5000, 512])
torch.Size([5000, 512])
torch.Size([5000, 512])
torch.Size([5000, 512])
torch.Size([5000, 512])
torch.Size([5000, 512])
torch.Size([5000, 512])
243 ms ± 4.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%%timeit
SinusoidalPositionalEncoding(512)

338 μs ± 1.77 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
print(SinusoidalPositionalEncoding(d_model=512).PE.shape)

torch.Size([5000, 512])


In [10]:
def test_sinusoidal_positional_encoding_class():
    max_len = 5000
    d_model = 512
    pe_class = SinusoidalPositionalEncoding(d_model=d_model, max_len=max_len)
    
    torch.allclose(pe_class.forward(), generate_position_encoding(max_len=max_len, d_model=d_model))

test_sinusoidal_positional_encoding_class()

torch.Size([5000, 512])


#### Function that calculates attention (single-headed for now).

#### A Decoder-Only Transformer Language Model

For now we are implementing it with single headed attention

In [None]:
"""
Dimension key:
L: sequence length
D: model dimension (d_model)
V: vocabulary size
F: feed-forward subnetwork's hidden size
K: size of each attention key or value (d_k,d_v,d_kv)
"""

In [64]:
class FeedForward(nn.Module):
    def __init__(self, D: int, H: int = 2048):
        super().__init__()
        self.layers = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(D, H)),
            ("relu1", nn.ReLU()),
            ("linear2", nn.Linear(H, D)),
        ]))
    
    def forward(self, x):
        return self.layers.forward(x)



class Decoder(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, P_drop: int):
        super().__init__()
        # attention function
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.P_drop = P_drop
        self.W_query_DK = nn.Parameter(torch.randn([self.d_model, self.d_k]), requires_grad=True)
        self.W_key_DK = nn.Parameter(torch.randn([self.d_model, self.d_k]), requires_grad=True)
        self.W_value_DK = nn.Parameter(torch.randn([self.d_model, self.d_v]), requires_grad=True)

        # normalize over the last dimension in shape [N, d_v], so each layer
        self.dropout1 = nn.Dropout(p=self.P_drop)
        self.layernorm1 = nn.LayerNorm(self.d_v)
        self.feedforward = FeedForward(D=self.d_model)
        self.dropout2 = nn.Dropout(p=self.P_drop)
        self.layernorm2 = nn.LayerNorm(self.d_model)
    
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert(input.dim() == 2 and input.shape[1] == self.d_model)
        input_token_length = input.shape[0]
        Q = torch.matmul(input, self.W_query_DK)
        K = torch.matmul(input, self.W_key_DK)
        V = torch.matmul(input, self.W_value_DK)

        attn_values = self.scaled_dot_prd_attention(Q, K, V)
        reg_attn_values = self.dropout1(attn_values)
        normalized_attn_values = self.layernorm1(reg_attn_values + input)
        assert(normalized_attn_values.shape == attn_values.shape)
        assert(normalized_attn_values.dim() == 2 and normalized_attn_values.shape[1] == self.d_model)

        # print(normalized_attn_values.shape)
        ffn_output = self.feedforward(normalized_attn_values)
        reg_ffn_output = self.dropout1(ffn_output)

        normalized_ffn = self.layernorm2(reg_ffn_output + normalized_attn_values)
        return normalized_ffn
    
    def scaled_dot_prd_attention(self, Q, K, V, mask: bool = True) -> torch.Tensor:
        assert (Q.shape == K.shape)

        K_t = rearrange(K, "dm dk -> dk dm")

        QK_t = einsum(Q, K_t, "i dk, dk k -> i k")
        scaled_QK_t = QK_t / math.sqrt(self.d_model)
        # since we want to average the attended keys between the dimension d_k
        tril_qk = torch.tril(scaled_QK_t)
        scaled_QK_t = torch.where(tril_qk == 0, float("-inf"), scaled_QK_t)
        # print(scaled_QK_t)
        # print(scaled_QK_t.shape)
        # raise Exception
        weighted_keys = torch.softmax(scaled_QK_t, dim=1)
        # print(weighted_keys[:10][:10])
        # raise Exception
        attention_output = einsum(weighted_keys, V, "input_len input_len, input_len dv -> input_len dv")
        return attention_output



In [65]:

class EmbeddingLayer(nn.Module):
    def __init__(self, vector_size: int, vocab_size: int):
        super().__init__()
        self.d_model = vector_size
        self.scaling_factor = math.sqrt(self.d_model) # a detail in the paper
        self.lut = nn.Embedding(vocab_size, vector_size)
    
    def forward(self, input_indices: torch.Tensor):
        return self.lut(input_indices) # * self.scaling_factor

    def get_weights(self):
        return self.lut.weight


In [72]:
class Transformer(nn.Module):
    def __init__(self, d_model: int = 512, num_decoders: int = 6, maximum_sequence_length: int = 5000, P_drop: float = 0.1, vocabulary_size: int = 10000):

        super().__init__()
        assert (num_decoders >= 1)
        self.d_model = d_model
        self.d_k, self.d_v = d_model, d_model
        self.max_sequence_length = maximum_sequence_length
        self.vocabulary_size = vocabulary_size

        self.P_drop = P_drop

        self.positional_encodings = SinusoidalPositionalEncoding(d_model=self.d_model)

        self.emb_dropout = nn.Dropout(p = self.P_drop)
        self.embedding = EmbeddingLayer(vector_size=self.d_model, vocab_size=self.vocabulary_size)
        # print("embed info", self.embedding.get_weights().shape)
        self.decoder_stack = nn.Sequential()
        for _ in range(num_decoders):
            self.decoder_stack.append(Decoder(d_model = self.d_model, d_k = self.d_k, d_v = self.d_v, P_drop = self.P_drop))

        self.linear = nn.Linear(self.d_model, self.vocabulary_size)
        # make the loss reduce without shared matrixes for now.
        # with torch.no_grad():
        #     self.linear.weight = self.embedding.get_weights()
        self.softmax = nn.LogSoftmax(dim=1) # this is LogSoftmax so we can get the NLLLoss
    
    def forward(self, input_tokens_L: torch.Tensor):
        assert(input_tokens_L.dim() == 1 and input_tokens_L.dtype == torch.int)
        L = input_tokens_L.shape[0]
        # how do we get from input tokens to indices?
        embed_tokens_LD = self.embedding(input_tokens_L)
        # print(embed_tokens.dim())
        # print(embed_tokens.shape)
        # print(sequence_length, self.d_model)
        assert(embed_tokens_LD.dim() == 2 and embed_tokens_LD.shape == torch.Size([L, self.d_model]))

        # Add positional encoding information
        # pos_enc = torch.zeros_like(embed_tokens)
        # for position in range(sequence_length):
        pos_enc_LD = self.positional_encodings()[:L, ...]
        embed_tokens_LD += pos_enc_LD

        decoder_output_LD = self.decoder_stack(embed_tokens_LD)

        assert(decoder_output_LD.dim() == 2 and decoder_output_LD.shape == torch.Size([L, self.d_model]))
        linear_output_LV = self.linear(decoder_output_LD)
        next_token_probabilities_LV = self.softmax(linear_output_LV)

        return next_token_probabilities_LV # this is of shape [seq_len, vocab_size] (each token has computed a probabilistic next token)
        """Now we map these back to the words based on the maximum token."""


In [73]:
def test_transformer_runs():
    """With a dummy input lets just test if the components of the transformer fit together"""
    input_ints = torch.tensor([1, 2, 3, 4, 3, 6, 5, 7, 8], dtype=torch.int).to("cuda")
    model = Transformer(d_model=16).to("cuda")
    print(model.forward(input_ints).shape)

test_transformer_runs()

torch.Size([9, 10000])


Dataset sample generation

In [74]:
import random
def create_random_sequence(vocab_size: int, max_sequence_length) -> torch.Tensor:
    # return torch.randint(0, vocab_size-1, (random.randint(10, (max_sequence_length/2)-1),), dtype=torch.int)
    return torch.randint(0, vocab_size-1, (max_sequence_length,), dtype=torch.int)

def create_reversed_training_sample(vocab_size: int, max_sequence_length):
    """Generate a palindromic tensor of token indices.
    Returns: [A,Rev(A)]
    """
    random_sample = create_random_sequence(vocab_size=vocab_size, max_sequence_length=max_sequence_length)
    return torch.stack((random_sample, torch.flip(random_sample, dims=(0,))))

sample_training_sample = create_reversed_training_sample(100000, 20)
print(sample_training_sample.shape)
print(sample_training_sample)
# print(rearrange(sample_training_sample, "s a -> (s a)", s=2))

torch.Size([2, 20])
tensor([[14509, 81570, 56940, 41326, 58356, 50691, 23291, 57432,  6395, 46661,
         67340, 88120,  9305, 29989, 92979, 49443, 92894, 64343, 98965, 82685],
        [82685, 98965, 64343, 92894, 49443, 92979, 29989,  9305, 88120, 67340,
         46661,  6395, 57432, 23291, 50691, 58356, 41326, 56940, 81570, 14509]],
       dtype=torch.int32)


In [75]:
def train_loop(dataset, model: torch.nn.Module, loss_fn, optimizer, device):
    model.train()
    for idx, sequence in enumerate(dataset):
        # print(sequence.shape)
        """Given a palindromic sequence of tokens of length L we are training a model to reverse these tokens.
            Thus we can train with inputs 0..L/2 -> output L/2+1, 0..L/2+1 -> output L/2+2
        """
        sequence = sequence.to(device)
        # prompt, output = sequence[0,...], sequence[1, ...]
        # assert prompt.dim() == 1, output.dim() == 1
        input = rearrange(sequence, "sequences seq_len -> (sequences seq_len)", sequences=2)

        # seq_len = int(prompt.shape)
        # print(input.dtype)

        seq_next_token_probabilities = model(input) # [input_len, V]

        # print(seq_next_token_probabilities.shape, seq_next_token_probabilities.dtype)
        output = loss_fn(seq_next_token_probabilities, input.to(torch.long))
        output.backward()
        optimizer.step()
        optimizer.zero_grad()
        if idx % 200 == 0:
            print(f"Step {idx} : Loss = {output}\n")



In [76]:
def test_loop(dataset, model: torch.nn.Module, loss_fn, device):
    size = len(dataset)
    print(f"Test set samples: {size}")
    model.eval()
    test_loss, correct = 0, 0
    total_tokens = 0
    with torch.no_grad():
        for idx, sequence in enumerate(dataset):
            total_tokens += sequence.numel()
            sequence = sequence.to(device)

            # Let's test full sequence matching, then test half sequence matching
            # if it fully matches the model has just memorized the test data.
            input = rearrange(sequence, "sequences seq_len -> (sequences seq_len)", sequences=2)
            pred = model(input)
            test_loss += loss_fn(pred, input.to(torch.long)).item()
            correct += (pred.argmax(dim=1) == input).sum().item()
    
    test_loss /= total_tokens
    correct /= total_tokens
    print(f"Test error: Accuracy: {100*correct}, Avg loss: {test_loss}\n")

### Train loop

In [78]:
vocab_size = 100
max_sequence_length = 50
dataset_size = 10000

device = torch.device("cuda")
train_dataset = [create_reversed_training_sample(vocab_size, max_sequence_length) for _ in range(int(dataset_size*.7))]

model = Transformer(num_decoders=6, maximum_sequence_length=max_sequence_length, vocabulary_size=vocab_size).to(device)
loss = F.nll_loss
adam_opt = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), lr=0.001) # parameters from paper. No LR schedule, default 0.001
train_loop(dataset=train_dataset, model=model, loss_fn=loss, optimizer=adam_opt, device=device)

Step 0 : Loss = 4.794222831726074

Step 200 : Loss = 0.2318265736103058

Step 400 : Loss = 0.10798325389623642

Step 600 : Loss = 0.06682287901639938

Step 800 : Loss = 0.15197552740573883

Step 1000 : Loss = 0.13703179359436035

Step 1200 : Loss = 0.08748020231723785

Step 1400 : Loss = 0.0799342542886734

Step 1600 : Loss = 0.06832216680049896

Step 1800 : Loss = 0.043675877153873444

Step 2000 : Loss = 0.09032051265239716

Step 2200 : Loss = 0.05185873433947563

Step 2400 : Loss = 0.5307435989379883

Step 2600 : Loss = 0.04221052676439285

Step 2800 : Loss = 0.15600234270095825

Step 3000 : Loss = 0.09824006259441376

Step 3200 : Loss = 0.04663380607962608

Step 3400 : Loss = 0.05471770837903023

Step 3600 : Loss = 0.11089969426393509

Step 3800 : Loss = 0.04336409270763397

Step 4000 : Loss = 0.23348098993301392

Step 4200 : Loss = 0.04427555203437805

Step 4400 : Loss = 0.0527845062315464

Step 4600 : Loss = 0.033761076629161835

Step 4800 : Loss = 0.037306416779756546

Step 5000 

Accuracy should be 50% ish since the model should only know the second half of the tokens.

In [80]:
test_dataset = [create_reversed_training_sample(vocab_size, max_sequence_length) for _ in range(int(dataset_size*.3))]
test_loop(test_dataset, model, loss, device)

Test set samples: 3000
Test error: Accuracy: 98.98933333333333, Avg loss: 0.00041133413260336965



In [81]:
def inference(model, sample, device: torch.device):
    model.eval()
    print(f"Sample {sample}\n")
    sample = sample.to(device)
    result = []
    with torch.no_grad():
        total_tokens = sample.numel()
        print(f"Total Sample length {total_tokens}\n")
        for input_window in range(1, total_tokens+1):
            input = sample[:input_window]

            # print(input.shape, input.dtype)
            # break
            all_next_token_preds = model(input)
            # print(all_next_token_preds)
            # print(all_next_token_preds.shape)
            # break
            pred_next_token = all_next_token_preds[-1, ...].argmax(0)
            result.append(pred_next_token.cpu())
            # print(f"Predicted next token {pred_next_token}, Actual next token {sample[input_window]}\n")
    


    return result

random_sample = [create_reversed_training_sample(vocab_size, max_sequence_length) for _ in range(1)][0]
random_sample = rearrange(random_sample, "a len -> (a len)")
# print(random_sample.shape)
# print(random_sample)
# print(random_sample[:2])
result = torch.tensor(inference(model, random_sample, device))
        



Sample tensor([56, 90, 28, 90, 84, 21, 33, 39, 35, 15, 87, 46, 75, 96, 16, 71, 48, 49,
        13, 69, 52, 21, 21, 76, 62, 40, 73, 25, 64, 93, 22, 54, 89, 10, 22, 59,
        70,  2,  1, 88, 72, 61, 70, 47, 56, 76,  1, 55, 97, 96, 96, 97, 55,  1,
        76, 56, 47, 70, 61, 72, 88,  1,  2, 70, 59, 22, 10, 89, 54, 22, 93, 64,
        25, 73, 40, 62, 76, 21, 21, 52, 69, 13, 49, 48, 71, 16, 96, 75, 46, 87,
        15, 35, 39, 33, 21, 84, 90, 28, 90, 56], dtype=torch.int32)

Total Sample length 100



In [82]:
print(result == random_sample)

tensor([False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True])


In [83]:
result

tensor([65, 90, 28, 90, 84, 21, 33, 39, 35, 15, 87, 46, 75, 96, 16, 71, 48, 49,
        13, 69, 52, 21, 21, 76, 62, 40, 73, 25, 64, 93, 22, 54, 89, 10, 22, 59,
        70,  2,  1, 88, 72, 61, 70, 47, 56, 76,  1, 55, 97, 96, 96, 97, 55,  1,
        76, 56, 47, 70, 61, 72, 88,  1,  2, 70, 59, 22, 10, 89, 54, 22, 93, 64,
        25, 73, 40, 62, 76, 21, 21, 52, 69, 13, 49, 48, 71, 16, 96, 75, 46, 87,
        15, 35, 39, 33, 21, 84, 90, 28, 90, 56])