In [1]:
from dataclasses import dataclass


@dataclass
class GeneralConfig:
    context_length: int


@dataclass
class TransformerLMConfig:
    vocab_size: int
    num_layers: int
    d_model: int
    num_heads: int
    d_ff: int = None

    # TODO: move context_length to a separate config
    context_length: int

    def __post_init__(self):
        self.d_ff = self.d_ff or 4 * self.d_model


general_config = {
    "vocab_size": 50_257,
    "context_length": 1_024,
}

model_configs = {
    "gpt-2-small": TransformerLMConfig(
        num_layers=12,
        d_model=768,
        num_heads=12,
    ),
    "gpt-2-medium": TransformerLMConfig(
        num_layers=24,
        d_model=1_024,
        num_heads=16,
    ),
    "gpt-2-large": TransformerLMConfig(
        num_layers=36,
        d_model=1_280,
        num_heads=20,
    ),
    "gpt-2-xl": TransformerLMConfig(
        num_layers=48,
        d_model=1_600,
        num_heads=25,
    ),
}

In [None]:
def count_transformer_lm_params(cfg: TransformerLMConfig):
    token_embeddings = cfg.vocab_size * cfg.d_model
    glu = 3 * cfg.d_model * cfg.d_ff
    qkvo_proj = 4 * cfg.d_model * cfg.d_model
    lm_head = cfg.vocab_size * cfg.d_model

    total = token_embeddings + cfg.num_layers * (glu + qkvo_proj) + lm_head

    return total


def count_transformer_lm_activations(cfg: TransformerLMConfig) -> int:
    layer_rms_norm = cfg.context_length * cfg.d_model

    attn_kqv = cfg.context_length * 3 * cfg.d_model
    attn_qk = cfg.context_length * cfg.context_length
    attn_softmax = attn_qk
    attn_values_weighted_sum = cfg.context_length * cfg.d_model
    attn_output_projection = cfg.context_length * cfg.d_model
    attn = attn_kqv + attn_qk + attn_softmax + attn_values_weighted_sum + attn_output_projection

    ffn_w1_mm = cfg.context_length * cfg.d_ff
    ffn_silu = ffn_w1_mm
    ffn_w2_mm = ffn_w1_mm
    ffn = ffn_w1_mm + ffn_silu + ffn_w2_mm

    layer = 2 * layer_rms_norm + attn + ffn

    final_rms_norm = cfg.context_length * cfg.d_model

    output_embedding = cfg.context_length * cfg.vocab_size  # (logits)

    total = cfg.num_layers * layer + final_rms_norm + output_embedding

    return total


def count_lm_transformer_adamw_memory_usage(cfg: TransformerLMConfig) -> int:
    params = count_transformer_lm_params(cfg)
    activations = count_transformer_lm_activations(cfg)

    return params + activations

TypeError: non-default argument 'context_length' follows default argument 'd_ff'