In [5]:
from typing import Annotated, Dict
import numpy.typing as npt

# Token embedding dimension. It's called D_MODEL in the original post.
# This is also State's dimension.
D_EMBED = 12288;
# The number of input tokens the transformer can handle at once.
# Also the length of the residual stream.
N_TOKEN = 1024;
# The number of unique subwords
N_VOCAB = 50000;

class Logits:
    """The logit for each subword
    """
    data: Annotated[list[float], N_VOCAB]
    def __init__(self, data):
        self.data = data

class TokenId:
    """Which subword
    """
    data: int
    def __init__(self, data):
        assert data < N_VOCAB and data >= 0
        self.data = data
    
    def __hash__(self):
        return hash(self.data)
    
    def __eq__(self, rhs):
        return isinstance(rhs, TokenId) and self.data == rhs.data
    

class State:
    """An item in the residual stream.
    """
    data: Annotated[list[float], D_EMBED]

class ResidualStream:
    """An abstract data structure that holds the current transformer's internal state.
    It keeps getting transformed by various phases of Transformer.run().
    You can think of this as an opaque POJO too, with a collection of key-value pairs,
    but the k-v pairs are encoded as an unordered set of N_TOKEN vectors of size D_EMBED.
    
    Note the meaning of this opaque object changes after each processing step.
    E.g. After the initial embedding step, this is a set of token embeddings.
    After the first attention step, this is a set of context-aware embeddings per token.
    After the first MLP, it's something mysterious :) Maybe a field in this opaque data structure can be
    a measure of how sentiment-positive the entire token collection is.
    """
    data: Annotated[list[State], N_TOKEN]
    def __init__(self, data):
        self.data = data

class Embedder:
    embedding_matrix: Dict[TokenId, Annotated[list[float], D_EMBED]]
    def __init__(self, embedding_matrix):
        self.embedding_matrix = embedding_matrix

    def run(self, input_tokens: list[TokenId]) -> ResidualStream:
        """Look up each token in the embedding matrix. Each token ID is mapped to D_EMBED vector.
        """
        states = [State(self.embedding_matrix[token_id] for token_id in input_tokens)]
        return ResidualStream(states)

class Unembedder:
    # TODO: Implement

class Transformer:

    def run(input_tokens: list[TokenId]) -> Logits:
        """A transformer accepts an unordered set of subwords / tokens and returns the logits of the next subword.
        
        [Positioning] The transformer doesn't actually know about the position of the subwords. We get around this by
        adding positional embedding into each token embedding. **But we don't do it in this exercise.
        """
        # TODO: Implement this
        # TODO: Embed
        return Logits([3.2]*N_VOCAB)

    def embed(input_tokens: list[TokenId]) -> ResidualStream:
        """Embeds the input tokens
        """

SyntaxError: unterminated string literal (detected at line 38) (1285603373.py, line 38)

In [2]:
import numpy as np