In [17]:
# Allows self-reference of class in typing hint.
# https://stackoverflow.com/a/36193829/21196296
from __future__ import annotations
from typing import Annotated, Callable, Dict
import numpy as np
import numpy.typing as npt

"""Type annotations"""
NDArrayInt = npt.NDArray[np.int_]
NDArrayFloat = npt.NDArray[np.float64]

"""Constants"""
# Token embedding dimension. It's called D_MODEL in the original post.
# This is also State's dimension.
D_EMBED = 12288
# The number of neurons in the hidden layer.
D_MLP = 4 * D_EMBED
# The smaller dimension where attention computation is happening.
# It's D_HEAD in the original post.
D_ATTENTION = 128
# The dimension of query and key vectors. It doesn't have to be D_ATTENTION, but in practice it is.
D_ATTENTION_QUERY = D_ATTENTION
# Number of attention heads
N_HEADS = 96
assert N_HEADS == D_EMBED / D_ATTENTION
# The number of input tokens the transformer can handle at once.
# Also the length of the residual stream.
N_TOKEN = 1024
# The number of unique subwords
N_VOCAB = 50000
# The number of residual blocks. Each residual block is a sequence of attention, MLP, and normalization layers
# but in this code, we just have the attention and MLP layer.
# It's called N_BLOCKS in the original post
N_BLOCKS = 96

class Logits:
    """The logit for each subword
    """
    data: Annotated[NDArrayFloat, N_VOCAB]
    def __init__(self, data):
        self.data = data

class TokenId:
    """Which subword
    """
    data: int
    def __init__(self, data):
        assert data < N_VOCAB and data >= 0
        self.data = data
    
    def __hash__(self):
        return hash(self.data)
    
    def __eq__(self, rhs):
        return isinstance(rhs, TokenId) and self.data == rhs.data


class StateUpdate:
    """An update to a single state."""
    data: Annotated[NDArrayFloat, D_EMBED]

    def add(self, update: StateUpdate) -> StateUpdate:
        """Add another state update to this one.
        """
        return StateUpdate(self.data + update.data)

class State:
    """An item in the residual stream.
    """
    data: Annotated[NDArrayFloat, D_EMBED]

    def apply_update(self, update: StateUpdate) -> State:
        """Apply an update to this state and return as a new state.
        """
        return State(self.data + update.data)

class ResidualStreamUpdate:
    """An update to the residual stream.
    """
    state_updates: Annotated[list[StateUpdate], N_TOKEN]

    def add(self, update: ResidualStreamUpdate) -> ResidualStreamUpdate:
        """Add another residual stream update to this one.
        """
        if update is None:
            return self
        return ResidualStreamUpdate([self.state_updates[i].add(update.state_updates[i]) for i in range(N_TOKEN)])

class ResidualStream:
    """An abstract data structure that holds the current transformer's internal state.
    It keeps getting transformed by various phases of Transformer.run().
    You can think of this as an opaque POJO too, with a collection of key-value pairs,
    but the k-v pairs are encoded as an unordered set of N_TOKEN vectors of size D_EMBED.
    
    Note the meaning of this opaque object changes after each processing step.
    E.g. After the initial embedding step, this is a set of token embeddings.
    After the first attention step, this is a set of context-aware embeddings per token.
    After the first MLP, it's something mysterious :) Maybe a field in this opaque data structure can be
    a measure of how sentiment-positive the entire token collection is.
    """
    data: Annotated[list[State], N_TOKEN]
    def __init__(self, data):
        self.data = data
    

    def apply_update(self, update: ResidualStreamUpdate) -> ResidualStream:
        # Apply update to each state independently.
        # Each state update is just a vector addition.
        new_states = []
        for i in range(N_TOKEN):
            new_states.append(self.data[i].apply_update(update.data[i]))
        return ResidualStream(new_states)

class TokenToStateEmbedder:
    embedding_matrix: Dict[TokenId, Annotated[NDArrayFloat, D_EMBED]]
    def __init__(self, embedding_matrix):
        self.embedding_matrix = embedding_matrix

    def run(self, input_token: TokenId) -> State:
        """Convert a token ID to a D_EMBED vector.
        """
        return State(self.embedding_matrix[input_token])

class LogitFn:

    # This vector, when dot-producted with a state,
    # returns a measure of how attuned the state is to this vocab.
    vocab_likeness_vector: Annotated[NDArrayFloat, D_EMBED]

    def __init__(self, vocab_likeness_vector):
        self.vocab_likeness_vector = vocab_likeness_vector
    
    """A function that takes a state and returns a logit.
    """
    def run(self, state: State) -> float:
        return np.dot(self.vocab_likeness_vector, state.data)
    
class StateToTokenLogitsUnembedder:
    """
    Various views for unembedding:
    (1 - what's being coded) A list of functions, one per vocab item.
       Each takes a state, and returns 1 logit for that vocab.
       E.g. a single function for "cake" measures how "cake"-like this state is, and returns a scalar.
    (2) A list of D_EMBED-dimensional vectors, one for each vocab item.
       You then dot-product these D_EMBED vectors with the state to get how aligned the state vector
       is with the vocab item.
    (3) A matrix of shape (N_VOCAB, D_EMBED). Each row can be dot-product-ed with the state to get
    a measure of how vocab-item-like the state is.
        Remember: (N_VOCAB, D_EMBED) X (D_EMBED, 1) = (N_VOCAB, 1)
    """
    logit_computer_per_vocab: Dict[TokenId, LogitFn]


    def __init__(self, logit_computer_per_vocab):
        self.logit_computer_per_vocab = logit_computer_per_vocab

    def run(self, state: State) -> Logits:
        """Convert 1 state of the final residual stream to logits of the next subword.
        Each state is kind of like a summary of 1 prefix of the input context / sentence.

        E.g. given a state summarizing the substring "I love you", return these logits for the following subword:
          "more": 10
          "dear": 10
          "you": 0
        """
        logit_per_vocab = []
        for vocab_i in range(N_VOCAB):
            logit_per_vocab.append(self.logit_computer_per_vocab[vocab_i].run(state))
        return logit_per_vocab

class AttentionHead:
    # A single attention head has a single Q,K,V logic.

    # How you convert a state to k,q,v vectors.
    q_projector = Annotated[NDArrayFloat, D_ATTENTION_QUERY, D_EMBED]
    k_projector = Annotated[NDArrayFloat, D_ATTENTION_QUERY, D_EMBED]
    v_projector = Annotated[NDArrayFloat, D_ATTENTION, D_EMBED]
    attention_to_state_projector = Annotated[NDArrayFloat, D_EMBED, D_ATTENTION]

    def run(self, residual_stream: ResidualStream) -> ResidualStreamUpdate:
        ## Precompute lower-D space representations for each state.

        # For each state, precompute 3 vectors: q(state), k(state), v(state)
        # qs_per_state has N_TOKEN vectors, each of size D_ATTENTION_QUERY
        qs_per_state = [self.q_projector @ state.data for state in residual_stream.data]
        ks_per_state = [self.k_projector @ state.data for state in residual_stream.data]
        # vs_per_state has N_TOKEN vectors, each of size D_ATTENTION
        vs_per_state = [self.v_projector @ state.data for state in residual_stream.data]

        state_updates = []
        ## Compute each state update independently. This is where the state-to-state mixing happens.
        for target_state_i in range(N_TOKEN):
            # I can only look at states before and including me to concoct the state update.
            states_to_mix = residual_stream.data[0::target_state_i+1]
            # how_much_target_cares[j] is how much states[target_state_i] cares about states[j]
            how_much_target_cares = []
            for k_state_j in range(len(states_to_mix)):
                how_much_target_cares.append(np.dot(qs_per_state[target_state_i], ks_per_state[k_state_j]))
            how_much_target_cares = np.softmax(how_much_target_cares)

            # Given how much we care, compute the context / summary vector that the target state
            # should use to compute its update
            weighted_vs = np.zeros(D_ATTENTION)
            for v_state_i in range(len(states_to_mix)):
                weighted_vs += how_much_target_cares[v_state_i] * vs_per_state[v_state_i]

            # Reproject from D_ATTENTION to D_EMBED
            state_update = StateUpdate(self.attention_to_state_projector @ weighted_vs)
            state_updates.append(state_update)
        return ResidualStreamUpdate(state_updates)


class AttentionLayer:

    attention_heads: Annotated[list[AttentionHead], N_HEADS]

    def run(self, residual_stream: ResidualStream) -> ResidualStreamUpdate:
        # Each head is applied independently to the residual stream
        # Then we simply sum them up
        # TODO: Is each head weighted the same? Aren't they concatenated then there's a weight matrix to reproject?
        all_head_residual_stream_update = None
        for head in self.attention_heads:
            per_head_residual_stream_update = head.run(residual_stream)
            all_head_residual_stream_update = per_head_residual_stream_update.add(all_head_residual_stream_update)
        return all_head_residual_stream_update
        

class MLPNeuron:
    # To convert from state to a scalar.
    # Also the column vector in the (D_EMBED, D_MLP) matrix.
    read_vector: Annotated[NDArrayFloat, D_EMBED]
    # To convert from scalar to state update.
    # Also the column vector in the (D_MLP, D_EMBED) matrix.
    write_vector: Annotated[NDArrayFloat, D_EMBED]
    """A neuron takes in a state and returns a state update. The final state update of the
    MLP neuron is just the sum of all the per-neuron state updates
    """
    def run(self, state: State) -> StateUpdate:
        scalar = np.dot(self.read_vector, state.data)
        # IRL they use GELU for the non-linearity. I just pick tanh because it's in numpy.
        scalar = np.tanh(scalar)
        return StateUpdate(self.write_vector * scalar)


class MLPLayerPerState:
        mlp_neurons: Annotated[list[MLPNeuron], D_MLP]

        def __init__(self, mlp_neurons) -> None:
            self.mlp_neurons = mlp_neurons

        def run(self, state: State) -> StateUpdate:
            """MLP layer is applied to each state in the residual stream individually.
            Given a state, it returns a state update.
            
            Here are various ways of looking at the MLP layer:
            (1) Matrix. We start with residual stream of shape (N_TOKEN, D_EMBED).
                (a) We multiply (N_TOKEN, D_EMBED) by (D_EMBED, D_MLP) to get (N_TOKEN, D_MLP)
                (b) Apply non-linearity element wise
                (c) We multiply (N_TOKEN, D_MLP) by (D_MLP, D_EMBED) so we can add back to residual stream
            (2) Per state. We map the (1) Matrix view to a per-state view by referring to the alphabet steps (a), (b), ...
                - We have N_TOKEN state vectors, each of size D_EMBED.
                (a) We multiply (N_TOKEN, D_EMBED) by (D_EMBED, D_MLP) to get (N_TOKEN, D_MLP)
                  - Each D_EMBED state vector is being processed independently by the same set of D_MLP neurons.
                    We now discuss what happens to each D_EMBED state vector.
                  - We convert the D_EMBED state vector to a D_MLP vector by feeding it into a set of D_MLP neurons.
                    - Each neuron takes a D_EMBED vector and returns a scalar independently
                      - It is able to do so because each neuron has its own D_EMBED vector to do dot product with.
                      This inner vector, in a sense, is what the neuron is looking for from a single state.
                      - To relate to interpretation (1), each neuron is a column vector in the (D_EMBED, D_MLP) matrix.
                    - Now that we have D_MLP scalars, we just concatenate them.
                (b) Apply non-linearity element wise
                  - Each of the D_MLP scalars is then passed through non-linearity in an element-wise fashion.
                (c) We multiply (N_TOKEN, D_MLP) by (D_MLP, D_EMBED) so we can add back to residual stream
                  - Each of the D_MLP neurons knows how much to use its 1 scalar to contribute to each of the D_EMBED
                  - It does so because internally it has another D_EMBED vector to figure out how to distribute its scalar
                    to the final state update vector.
            """
            # Run each neuron independently on the state, then sum up the state updates.
            final_state_update = np.zeros(D_EMBED)
            for neuron in self.mlp_neurons:
                state_update = neuron.run(state)
                final_state_update += state_update.data
            return StateUpdate(final_state_update)

class Block:

    attention_layer: AttentionLayer
    mlp_layer_per_state: MLPLayerPerState
    def run(self, residual_stream: ResidualStream) -> ResidualStream:
        """A residual block processes a residual stream and returns one with a new meaning.
        Conceptually, it's a function that accepts an opaque data structure and returns a data structure
        with a new meaning. Don't be fooled by the fact that we are both receiving and returning ResidualStream.
        We might as well call it ResidualStreamLayer1 and ResidualStreamLayer2.
        E.g. it's as if it's a function that accepts the class PerTokenEmbedding and returns a totally different
        class called ContextualAwareEmbedding, or some other data structure with a more advanced meaning.
        """
        
        attention_layer_update = self.attention_layer.run(residual_stream)
        residual_stream = residual_stream.apply_update

        # MLP layer is applied to each state in the residual stream individually.
        new_states = residual_stream.data
        for state_i in range(len(residual_stream.data)):
            state_update = self.mlp_layer_per_state.run(new_states[state_i])
            new_states[state_i] = new_states[state_i].apply_update(state_update)
        return ResidualStream(new_states)

class Transformer:

    token_to_state_embedder: TokenToStateEmbedder
    state_to_token_logits_unembedder: StateToTokenLogitsUnembedder
    blocks: Annotated[list[Block], N_BLOCKS]

    def __init__(self, token_to_state_embedder, state_to_token_logits_unembedder, blocks):
        self.token_to_state_embedder = token_to_state_embedder
        self.state_to_token_logits_unembedder = state_to_token_logits_unembedder
        self.blocks = blocks

    def run(self, input_tokens: Annotated[list[TokenId], N_TOKEN]) -> Annotated[list[Logits], N_TOKEN]:
        """A transformer accepts an ordered list of subwords / tokens and returns, for each position, the logits of the next subword.

        [For each prefix] It's important to note that the return value is NOT just the logits for the next subword following the
        context / sentence, but rather, the logits of the next subword FOR EACH non-empty prefix of the input.
        E.g. if the input is "this is my sentence", then we will return 4 logits. The first logit is for the next subword after "this",
        whereas the fourth / last logit is for the next subword after "this is my sentence".
        This might look silly at first, since you might think you only need the final logit, but it's useful to get multiple loss numbers
        for training.

        [Positioning] The transformer doesn't actually know about the position of the subwords. We get around this by
        adding positional embedding into each token embedding. **But we don't do it in this exercise.
        """

        # For each token, we embed into a D_EMBED vector.
        states = []
        for token_i in range(len(input_tokens)):
            state = self.token_to_state_embedder.run(input_tokens[token_i])
            states.append(state)

        # At this point, the initial residual stream is just the token embeddings.
        residual_stream = ResidualStream(states)

        # A lot of heavy processing on the residual stream by going through a sequence of blocks.
        # The meaning of the residual stream changes after each block.
        for block in (self.blocks):
            residual_stream = block.run(residual_stream)

        # At this point, the residual stream is N_TOKENS states.
        # Each state is a summary of a prefix for the input.
        # That is, states[2] is a summary of "I love you", whereas states[1] is a summary of "I love".
        # This happens because in the attention layers, we ensure that states[i] will never depend on
        # the initial token embeddings of tokens after i.
        logits_for_each_prefix = []
        for prefix_summary_state in residual_stream.data:
            logits = self.state_to_token_logits_unembedder.run(prefix_summary_state)
            logits_for_each_prefix.append(logits)
        return logits_for_each_prefix

IndentationError: expected an indented block after function definition on line 163 (738994268.py, line 166)

In [2]:
import numpy as np