In [10]:
from typing import Annotated, Callable, Dict
import numpy.typing as npt

# Token embedding dimension. It's called D_MODEL in the original post.
# This is also State's dimension.
D_EMBED = 12288;
# The number of input tokens the transformer can handle at once.
# Also the length of the residual stream.
N_TOKEN = 1024;
# The number of unique subwords
N_VOCAB = 50000;

class Logits:
    """The logit for each subword
    """
    data: Annotated[list[float], N_VOCAB]
    def __init__(self, data):
        self.data = data

class TokenId:
    """Which subword
    """
    data: int
    def __init__(self, data):
        assert data < N_VOCAB and data >= 0
        self.data = data
    
    def __hash__(self):
        return hash(self.data)
    
    def __eq__(self, rhs):
        return isinstance(rhs, TokenId) and self.data == rhs.data
    

class State:
    """An item in the residual stream.
    """
    data: Annotated[list[float], D_EMBED]

class ResidualStream:
    """An abstract data structure that holds the current transformer's internal state.
    It keeps getting transformed by various phases of Transformer.run().
    You can think of this as an opaque POJO too, with a collection of key-value pairs,
    but the k-v pairs are encoded as an unordered set of N_TOKEN vectors of size D_EMBED.
    
    Note the meaning of this opaque object changes after each processing step.
    E.g. After the initial embedding step, this is a set of token embeddings.
    After the first attention step, this is a set of context-aware embeddings per token.
    After the first MLP, it's something mysterious :) Maybe a field in this opaque data structure can be
    a measure of how sentiment-positive the entire token collection is.
    """
    data: Annotated[list[State], N_TOKEN]
    def __init__(self, data):
        self.data = data

class TokenToStateEmbedder:
    embedding_matrix: Dict[TokenId, Annotated[list[float], D_EMBED]]
    def __init__(self, embedding_matrix):
        self.embedding_matrix = embedding_matrix

    def run(self, input_token: TokenId) -> State:
        """Convert a token ID to a D_EMBED vector.
        """
        return State(self.embedding_matrix[input_token])

class StateToTokenLogitsUnembedder:
    """
    Various views for unembedding:
    (1 - what's being coded) A list of functions, one per vocab item.
       Each takes a state, and returns 1 logit for that vocab.
       E.g. a single function for "cake" measures how "cake"-like this state is, and returns a scalar.
    (2) A list of D_EMBED-dimensional vectors, one for each vocab item.
       You then dot-product these D_EMBED vectors with the state to get how aligned the state vector
       is with the vocab item.
    (3) A matrix of shape (N_VOCAB, D_EMBED). Each row can be dot-product-ed with the state to get
    a measure of how vocab-item-like the state is.
        Remember: (N_VOCAB, D_EMBED) X (D_EMBED, 1) = (N_VOCAB, 1)
    """
    logit_computer_per_vocab: Annotated[list[Callable[[State], float]], N_VOCAB]

    def run(self, state: State) -> Logits:
        """Convert 1 state of the final residual stream to logits of the next subword.
        Each state is kind of like a summary of 1 prefix of the input context / sentence.
        """
        logit_per_vocab = []
        for vocab_i in range(N_VOCAB):
            logit_per_vocab.append(self.logit_computer_per_vocab[vocab_i](state))
        return logit_per_vocab
    
class Transformer:

    token_to_state_embedder: TokenToStateEmbedder
    state_to_token_logits_unembedder: StateToTokenLogitsUnembedder

    def run(self, input_tokens: Annotated[list[TokenId], N_TOKEN]) -> Annotated[list[Logits], N_TOKEN]:
        """A transformer accepts an ordered list of subwords / tokens and returns, for each position, the logits of the next subword.

        [For each prefix] It's important to note that the return value is NOT just the logits for the next subword following the
        context / sentence, but rather, the logits of the next subword FOR EACH non-empty prefix of the input.
        E.g. if the input is "this is my sentence", then we will return 4 logits. The first logit is for the next subword after "this",
        whereas the fourth / last logit is for the next subword after "this is my sentence".
        This might look silly at first, since you might think you only need the final logit, but it's useful to get multiple loss numbers
        for training.

        [Positioning] The transformer doesn't actually know about the position of the subwords. We get around this by
        adding positional embedding into each token embedding. **But we don't do it in this exercise.
        """

        # For each token, we embed into a D_EMBED vector.
        states = []
        for token_i in range(len(input_tokens)):
            state = self.token_to_state_embedder.run(input_tokens[token_i])
            states.append(state)

        # At this point, the initial residual stream is just the token embeddings.
        residual_stream = ResidualStream(states)

        # TODO: Then a lot of processing
        
        logits_for_each_prefix = []
        for prefix_summary_state in residual_stream.data:
            # Each state is a summary of a prefix for the input.
            # That is, states[2] is a summary of "I love you", whereas states[1] is a summary of "I love".
            # This happens because in the attention layers, we ensure that states[i] will never depend on
            # the initial token embeddings of tokens after i.
            logits = self.state_to_token_logits_unembedder.run(prefix_summary_state)
            logits_for_each_prefix.append(logits)
        return logits_for_each_prefix

    def embed(input_tokens: list[TokenId]) -> ResidualStream:
        """Embeds the input tokens
        """

SyntaxError: unterminated string literal (detected at line 68) (2870065211.py, line 68)

In [2]:
import numpy as np