In [43]:
from typing import Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from einops import rearrange, einsum

import math
from collections import OrderedDict
torch.manual_seed(1234)

<torch._C.Generator at 0x76bf0fd46630>

### Sinusoidal Positional Encodings

This is used to inject some positional information into the embeddings of the token sequence. Since we compute the next token for every single token in parallel, we want some info about the position of a token within a sequence to be represented as well.
1. Unique encoding for each position (across all sequences)
2. Generalize to longer sequence than seen in training
3. Generated deterministically (so the model can learn it)
4. Linear relation between 2 encoded positions (again to help the model learn relationships)


Given a position $pos$ output a vector $d_{model}$ such that for each location $i$ in the vector the output is $$PE_{pos,2i} = \sin(pos/10000^{2i/d_{model}})$$ $$PE_{pos,2i+1} = \cos(pos/10000^{2i/d_{model}})$$
for even and odd indices respectively

we refactor as $$PE_{pos,2i} = \sin(pos.w)$$ $$PE_{pos,2i+1} = \cos(pos.w)$$ where $w=1/(10000^{2i/d_{model}}) for 0 <= 2i <= d_{model}$

![Sinusoidal PE visualization](fleetwood_sinusoidal.png "https://fleetwood.dev/posts/you-could-have-designed-SOTA-positional-encoding")
Look at the functions. For a dimension $i$ the $sin/cos$ function's wave starts out extremely quickly changing values for small $i$, and slows down a lot up to a wavelength of $10000*2\pi$.

#### Derivation to use tensor ops
$$\frac{1}{10000^{k/d_{model}}} = 10000^{-k/d_{model}} = \exp(\log(10000^{-k/d_{model}}))$$
$$ = \exp(-k/d_{model} * \log(10000))$$
this is basically
$$torch.exp(k * (-\frac{1}{d_{model}}) * \log(10000))$$

In [2]:
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000) -> None:
        """Construct the entire positional encoding."""
        super().__init__()
        w = torch.exp(torch.arange(0, d_model, 2) * (-1/d_model) * math.log(10000))
        PE = torch.zeros([max_len, d_model])
        positions = torch.arange(0, max_len)
        # We have 2 vectors
        assert (positions.shape == (max_len,))
        assert (w.shape == (d_model // 2,))
        # If we reshape positions into a row of columns [[0], [1], [2], ...] then 
        # positions @ w -> gives us a matrix of shape [max_len, d_model/2 ]
        positions = rearrange(positions, "(len column) -> len column", column=1)
        PE[:, 0::2] = torch.sin(positions * w)
        PE[:, 1::2] = torch.cos(positions * w)
        self.register_buffer("PE", PE) # this is not to be a learnable parameter
        # However we do want it to be moved along with model.to(device) 
    
    def forward(self):
        return self.PE

Just a reference implementation with for-loops to demonstrate speedup of torch ops.

In [3]:
# output pe -> [d_model]
def sinusoidal_position_encoding(pos: int, d_model: int = 512) -> torch.Tensor:
    # w = 1/10000**(2i/d_model)
    w = torch.tensor([1/(10000**(k/d_model)) for k in range(0, d_model, 2)])
    x_indices = pos * w 
    PE_even = torch.sin(x_indices)
    PE_odd = torch.cos(x_indices)
    PE_even = rearrange(PE_even, "(new_dim l) -> l new_dim", new_dim=1)
    PE_odd = rearrange(PE_odd, "(new_dim l) -> l new_dim", new_dim=1)
    interleaved_result = rearrange([PE_even, PE_odd], "function d_model element -> (d_model element function)")
    assert (interleaved_result.shape == (d_model,))
    return interleaved_result

def generate_position_encoding(max_len: int = 5000, d_model: int = 512) -> torch.Tensor:
    all_positions_encoding = torch.zeros([max_len, d_model], dtype=torch.float32)
    for position in range(max_len):
        all_positions_encoding[position, :] += sinusoidal_position_encoding(position, d_model)
    return all_positions_encoding

#### Note below the speed differential of the vectorized pytorch ops. Eye the units 👀

In [4]:
%%timeit
generate_position_encoding()

248 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
%%timeit
SinusoidalPositionalEncoding(512)

627 μs ± 64.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
print(SinusoidalPositionalEncoding(d_model=512).PE.shape)

torch.Size([5000, 512])


In [7]:
def test_sinusoidal_positional_encoding_class():
    max_len = 5000
    d_model = 512
    pe_class = SinusoidalPositionalEncoding(d_model=d_model, max_len=max_len)
    
    torch.allclose(pe_class.forward(), generate_position_encoding(max_len=max_len, d_model=d_model))

test_sinusoidal_positional_encoding_class()

#### Function that calculates attention (single-headed for now).

#### A Decoder-Only Transformer Language Model

For now we are implementing it with single headed attention

In [8]:
"""
Dimension key:
L: sequence length
D: model dimension (d_model)
V: vocabulary size
F: feed-forward subnetwork's hidden size
K: size of each attention key or value (d_k,d_v,d_kv)
"""

"\nDimension key:\nL: sequence length\nD: model dimension (d_model)\nV: vocabulary size\nF: feed-forward subnetwork's hidden size\nK: size of each attention key or value (d_k,d_v,d_kv)\n"

In [9]:
class EmbeddingLayer(nn.Module):
    def __init__(self, vector_size: int, vocab_size: int):
        super().__init__()
        self.d_model = vector_size
        self.scaling_factor = math.sqrt(self.d_model) # a detail in the paper
        self.lut = nn.Embedding(vocab_size, vector_size)
    
    def forward(self, input_indices: torch.Tensor):
        return self.lut(input_indices) # * self.scaling_factor

    def get_weights(self):
        return self.lut.weight

In [40]:
class FeedForward(nn.Module):
    def __init__(self, D: int, H: int = 2048):
        super().__init__()
        self.layers = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(D, H)),
            ("relu1", nn.ReLU()),
            ("linear2", nn.Linear(H, D)),
        ]))
    
    def forward(self, x):
        return self.layers.forward(x)



class Decoder(nn.Module):
    def __init__(self, d_model: int, d_k: int, d_v: int, P_drop: int):
        super().__init__()
        # attention function
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.P_drop = P_drop
        self.W_query_DK = nn.Parameter(torch.randn([self.d_model, self.d_k]), requires_grad=True)
        self.W_key_DK = nn.Parameter(torch.randn([self.d_model, self.d_k]), requires_grad=True)
        self.W_value_DK = nn.Parameter(torch.randn([self.d_model, self.d_v]), requires_grad=True)

        # normalize over the last dimension in shape [N, d_v], so each layer
        self.dropout1 = nn.Dropout(p=self.P_drop)
        self.layernorm1 = nn.LayerNorm(self.d_v)
        self.feedforward = FeedForward(D=self.d_model)
        self.dropout2 = nn.Dropout(p=self.P_drop)
        self.layernorm2 = nn.LayerNorm(self.d_model)
    
    def forward(self, input_BLD: torch.Tensor) -> torch.Tensor:
        assert(input_BLD.dim() == 3 and input_BLD.shape[-1] == self.d_model)
        input_token_length = input_BLD.shape[1]

        Query_BLK = torch.matmul(input_BLD, self.W_query_DK)
        Key_BLK = torch.matmul(input_BLD, self.W_key_DK)
        Value_BLK = torch.matmul(input_BLD, self.W_value_DK)

        attn_values_BLV = self.scaled_dot_prd_attention(Query_BLK, Key_BLK, Value_BLK)
        reg_attn_values_BLV = self.dropout1(attn_values_BLV)
        normalized_attn_values_BLV = self.layernorm1(reg_attn_values_BLV + input_BLD)
        assert(normalized_attn_values_BLV.shape == attn_values_BLV.shape)
        assert(normalized_attn_values_BLV.dim() == 3 and normalized_attn_values_BLV.shape[-1] == self.d_model)

        # print(normalized_attn_values.shape)
        ffn_BLD = self.feedforward(normalized_attn_values_BLV)
        reg_ffn_BLD = self.dropout1(ffn_BLD)

        normalized_ffn_BLD = self.layernorm2(reg_ffn_BLD + normalized_attn_values_BLV)
        return normalized_ffn_BLD
    
    def scaled_dot_prd_attention(self, Query_BLK, Key_BLK, Value_BLK, mask: bool = True) -> torch.Tensor:
        assert (Query_BLK.shape == Key_BLK.shape)

        Key_BKL = rearrange(Key_BLK, "B L K -> B K L")

        # QK_t_LL = einsum(Query_BLK, Key_BKL, "B i dk, B dk k -> B i k")

        qk_BLL = einsum(Query_BLK, Key_BKL, "B S_Q K, B K S_K -> B S_Q S_K")
        scaled_qk_BLL = qk_BLL / math.sqrt(self.d_model)
        # since we want to average the attended keys between the dimension d_k
        if mask:
            tril_qk = torch.tril(scaled_qk_BLL)
            scaled_qk_BLL = torch.where(tril_qk == 0, float("-inf"), scaled_qk_BLL)
        # print(scaled_QK_t_LL)
        # print(scaled_QK_t_LL.shape)
        # raise Exception
        weighted_keys_BLL = torch.softmax(scaled_qk_BLL, dim=-1)
        # print(weighted_keys)
        # raise Exception
        attention_output_BLV = einsum(weighted_keys_BLL, Value_BLK, "B S_Q L , B L D_V -> B S_Q D_V") 
        return attention_output_BLV



In [41]:
class Transformer(nn.Module):
    def __init__(self, d_model: int = 512, num_decoders: int = 6, maximum_sequence_length: int = 5000, P_drop: float = 0.1, vocabulary_size: int = 10000):

        super().__init__()
        assert (num_decoders >= 1)
        self.d_model = d_model
        self.d_k, self.d_v = d_model, d_model
        self.max_sequence_length = maximum_sequence_length
        self.vocabulary_size = vocabulary_size

        self.P_drop = P_drop

        self.positional_encodings = SinusoidalPositionalEncoding(d_model=self.d_model)

        self.emb_dropout = nn.Dropout(p = self.P_drop)
        self.embedding = EmbeddingLayer(vector_size=self.d_model, vocab_size=self.vocabulary_size)
        # print("embed info", self.embedding.get_weights().shape)
        self.decoder_stack = nn.Sequential()
        for _ in range(num_decoders):
            self.decoder_stack.append(Decoder(d_model = self.d_model, d_k = self.d_k, d_v = self.d_v, P_drop = self.P_drop))

        self.linear = nn.Linear(self.d_model, self.vocabulary_size)
        # make the loss reduce without shared matrixes for now.
        # with torch.no_grad():
        #     self.linear.weight = self.embedding.get_weights()
        self.softmax = nn.LogSoftmax(dim=1) # this is LogSoftmax so we can get the NLLLoss
    
    def forward(self, input_tokens_BL: torch.Tensor):
        assert(input_tokens_BL.dim() == 2 and input_tokens_BL.dtype == torch.int)
        B, L = input_tokens_BL.shape
        # how do we get from input tokens to indices?
        embed_tokens_BLD = self.embedding(input_tokens_BL)
        # print(embed_tokens.dim())
        # print(embed_tokens.shape)
        # print(sequence_length, self.d_model)
        assert(embed_tokens_BLD.dim() == 3 and embed_tokens_BLD.shape == torch.Size([B, L, self.d_model]))

        # Add positional encoding information
        # pos_enc = torch.zeros_like(embed_tokens)
        # for position in range(sequence_length):
        pos_enc_LD = self.positional_encodings()[None, :L, ...]
        embed_tokens_BLD += pos_enc_LD

        decoder_output_BLD = self.decoder_stack(embed_tokens_BLD)

        assert(decoder_output_BLD.dim() == 3 and decoder_output_BLD.shape == torch.Size([B, L, self.d_model]))
        linear_output_BLV = self.linear(decoder_output_BLD)
        next_token_probabilities_BLV = self.softmax(linear_output_BLV)

        return next_token_probabilities_BLV # this is of shape [seq_len, vocab_size] (each token has computed a probabilistic next token)
        """Now we map these back to the words based on the maximum token."""


In [42]:
def test_transformer_runs():
    """With a dummy input lets just test if the components of the transformer fit together"""
    input_ints = torch.tensor([1, 2, 3, 4, 3, 6, 5, 7, 8], dtype=torch.int).to("cuda")
    batched_input = torch.stack((input_ints, input_ints))
    model = Transformer(d_model=16).to("cuda")
    print(model.forward(batched_input).shape)

test_transformer_runs()

torch.Size([2, 9, 10000])


Dataset sample generation

In [52]:
import random

def create_reversed_training_sample(vocab_size: int, max_sequence_length):
    """Generate a palindromic tensor of token indices.
    Returns: [A,Rev(A)]
    """
    random_sample_L = torch.randint(2, vocab_size-1, (max_sequence_length,), dtype=torch.int)
    # shift the sequence right by 1
    # random_sample_L = torch.cat([start_token.unsqueeze(dim=0), random_sample_L], dim=0)

    return torch.cat((random_sample_L, torch.flip(random_sample_L, dims=(0,))))


sample_training_sample = create_reversed_training_sample(100000, 20)
print(sample_training_sample.shape)
print(sample_training_sample)
# print(rearrange(sample_training_sample, "s a -> (s a)", s=2))

torch.Size([40])
tensor([12287, 31818, 12844, 63123, 62556, 10011, 42987, 68538, 86210, 77636,
        65963, 24215, 32346, 90622, 67027, 66680, 24456, 27991, 59829, 63947,
        63947, 59829, 27991, 24456, 66680, 67027, 90622, 32346, 24215, 65963,
        77636, 86210, 68538, 42987, 10011, 62556, 63123, 12844, 31818, 12287],
       dtype=torch.int32)


In [114]:
START_TOKEN, END_TOKEN = torch.tensor(0, dtype=torch.int32), torch.tensor(1, dtype=torch.int32)

def shift_decoder_input_right(sample_BL: torch.Tensor, start_token: torch.Tensor) -> torch.Tensor:
    st_B = torch.zeros([sample_BL.shape[0], 1]).to(sample_BL.get_device(), dtype=torch.int32)
    return torch.cat([st_B, sample_BL], dim=1)

def pad_input_right(sample_BL: torch.Tensor, end_token: torch.Tensor) -> torch.Tensor:
    ed_B = torch.ones([sample_BL.shape[0], 1]).to(sample_BL.get_device(), dtype=torch.int32)
    return torch.cat([sample_BL, ed_B], dim=1)

In [None]:
def train_loop(dataloader, model: torch.nn.Module, loss_fn, optimizer, device, epochs):
    model.train()
    # for idx, sequence_2L in enumerate(dataset):
    for ep in range(epochs):
        print(f"Epoch {ep}\n")
        for idx, batch in enumerate(dataloader):
            # print(sequence.shape)
            """Given a palindromic sequence of tokens of length L we are training a model to reverse these tokens.
                Thus we can train with inputs 0..L/2 -> output L/2+1, 0..L/2+1 -> output L/2+2
            """
            batch_BL = batch.to(device)

            start_input_L = shift_decoder_input_right(batch_BL, START_TOKEN)
            seq_next_token_probabilities_LV = model(start_input_L)

            end_input_L = pad_input_right(batch_BL, END_TOKEN)

            palindrome_seq_len = batch_BL[0].numel()
            second_half_outputs_Vocab = rearrange(seq_next_token_probabilities_LV[:, palindrome_seq_len//2:, ...], "B L V -> (B L) V")
            second_half_targets_BxL = rearrange(end_input_L[:, palindrome_seq_len//2:, ...], "B L -> (B L)").to(torch.long)
            output = loss_fn(second_half_outputs_Vocab, second_half_targets_BxL)
            output.backward()
            optimizer.step()
            optimizer.zero_grad()
            if idx % 200 == 0:
                print(f"Step {idx}: Loss = {output}\n")



In [129]:
def test_loop(dataset, model: torch.nn.Module, loss_fn, device):
    size = len(dataset)
    print(f"Test set samples: {size}")
    model.eval()
    test_loss, correct = 0, 0
    total_tokens = 0
    with torch.no_grad():
        for idx, sequence in enumerate(dataset):
            total_tokens += sequence.numel()
            sequence = sequence.to(device)

            # Let's test full sequence matching, then test half sequence matching
            # if it fully matches the model has just memorized the test data.
            input_L = rearrange(sequence, "sequences seq_len -> (sequences seq_len)", sequences=2)
            start_input_L = shift_decoder_input_right(input_L, START_TOKEN)
            seq_next_token_probabilities_LV = model(start_input_L)

            end_input_L = pad_input_right(input_L, END_TOKEN)
            palindrome_seq_len = input_L.numel()

            test_loss += loss_fn(seq_next_token_probabilities_LV[palindrome_seq_len//2:, ...], end_input_L[ palindrome_seq_len//2:, ...].to(torch.long)).item()
            correct += (seq_next_token_probabilities_LV.argmax(dim=1) == end_input_L).sum().item()
    
    test_loss /= total_tokens
    correct /= total_tokens
    print(f"Test error: Accuracy: {100*correct}, Avg loss: {test_loss}\n")

### Train loop

In [135]:
vocab_size = 6
max_sequence_length = 4
dataset_size = 1000000
epochs = 10

device = torch.device("cuda")
train_dataset = [create_reversed_training_sample(vocab_size, max_sequence_length) for _ in range(int(dataset_size*.7))]
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)

model = Transformer(num_decoders=6, maximum_sequence_length=max_sequence_length, vocabulary_size=vocab_size).to(device)
loss = F.nll_loss
adam_opt = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), lr=0.001) # parameters from paper. No LR schedule, default 0.001
train_loop(dataloader=train_dataloader, model=model, loss_fn=loss, optimizer=adam_opt, device=device, epochs=epochs)

Epoch 0

Step 0: Loss = 2.234431505203247

Step 200: Loss = 1.951904535293579

Step 400: Loss = 1.945324182510376

Step 600: Loss = 1.9388620853424072

Step 800: Loss = 1.918243646621704

Step 1000: Loss = 1.904680609703064

Step 1200: Loss = 1.9881871938705444

Step 1400: Loss = 1.8915694952011108

Step 1600: Loss = 1.8725407123565674

Step 1800: Loss = 1.960019826889038

Step 2000: Loss = 1.8707574605941772

Step 2200: Loss = 1.797753095626831

Step 2400: Loss = 1.8037116527557373

Step 2600: Loss = 1.8954271078109741

Step 2800: Loss = 1.8452904224395752

Step 3000: Loss = 1.8310149908065796

Step 3200: Loss = 1.8663822412490845

Step 3400: Loss = 1.8033033609390259

Step 3600: Loss = 1.799922227859497

Step 3800: Loss = 1.7805511951446533

Step 4000: Loss = 1.8652575016021729

Step 4200: Loss = 1.8441060781478882

Step 4400: Loss = 1.8702433109283447

Step 4600: Loss = 1.8127212524414062

Step 4800: Loss = 1.7783381938934326

Step 5000: Loss = 1.8895353078842163

Step 5200: Loss = 

KeyboardInterrupt: 

Accuracy should be 50% ish since the model should only know the second half of the tokens.

In [132]:
test_dataset = [create_reversed_training_sample(vocab_size, max_sequence_length) for _ in range(int(dataset_size*.3))]
test_loop(test_dataset, model, loss, device)

Test set samples: 30000


EinopsError:  Error while processing rearrange-reduction pattern "sequences seq_len -> (sequences seq_len)".
 Input tensor shape: torch.Size([8]). Additional info: {'sequences': 2}.
 Wrong shape: expected 2 dims. Received 1-dim tensor.

In [91]:
def inference(model, sample, device: torch.device):
    model.eval()
    print(f"Sample {sample}\n")
    input_L = sample.to(device)
    result = []
    with torch.no_grad():
        total_tokens = sample.numel()
        print(f"Total Sample length {total_tokens}\n")

        
        start_input_L = shift_decoder_input_right(input_L, START_TOKEN)
        seq_next_token_probabilities_LV = model(start_input_L)

        end_input_L = pad_input_right(input_L, END_TOKEN)
        palindrome_seq_len = input_L.numel()

        result = seq_next_token_probabilities_LV.argmax(dim=1)

    print(result == end_input_L)
    print(result)
    return result

random_sample = [create_reversed_training_sample(vocab_size, max_sequence_length) for _ in range(1)][0]
random_sample = rearrange(random_sample, "a len -> (a len)")
# print(random_sample.shape)
# print(random_sample)
# print(random_sample[:2])
result = torch.tensor(inference(model, random_sample, device))
        



Sample tensor([2, 5, 7, 4, 3, 5, 5, 4, 4, 8, 7, 8, 8, 2, 4, 7, 6, 7, 3, 4, 4, 3, 7, 6,
        7, 4, 2, 8, 8, 7, 8, 4, 4, 5, 5, 3, 4, 7, 5, 2], dtype=torch.int32)

Total Sample length 40

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True, False, False, False, False, False, False,
        False, False, False, False, False, False,  True, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
         True], device='cuda:0')
tensor([1, 8, 8, 2, 2, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1], device='cuda:0')


  result = torch.tensor(inference(model, random_sample, device))
