In [1]:
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class TransformerLMConfig:
    """Size hyperparameters for a Transformer language model and its input batch.

    ``d_ff`` is optional: when it is left as ``None`` it is filled in with
    ``4 * d_model`` right after construction, so readers of the instance
    always see an integer.
    """

    batch_size: int
    context_length: int

    vocab_size: int
    num_layers: int
    d_model: int
    num_heads: int
    # None means "not specified"; replaced by 4 * d_model in __post_init__.
    d_ff: Optional[int] = None

    def __post_init__(self):
        """Fill in the default feed-forward width when none was given."""
        # Test for None explicitly (rather than truthiness) so that only an
        # omitted value triggers the default; any explicit integer is kept.
        if self.d_ff is None:
            self.d_ff = 4 * self.d_model


# Keyword arguments passed unchanged to every GPT-2 variant below.
shared = dict(batch_size=1, context_length=1_024, vocab_size=50_257)

# One config per GPT-2 size. Each table row is
# (name, num_layers, d_model, num_heads); d_ff is not passed, so the config
# class supplies its own default for it.
model_configs = {
    name: TransformerLMConfig(
        **shared,
        num_layers=depth,
        d_model=width,
        num_heads=heads,
    )
    for name, depth, width, heads in (
        ("gpt-2-small", 12, 768, 12),
        ("gpt-2-medium", 24, 1_024, 16),
        ("gpt-2-large", 36, 1_280, 20),
        ("gpt-2-xl", 48, 1_600, 25),
    )
}

In [9]:
def count_transformer_lm_params(cfg: TransformerLMConfig):
    """Return the number of weight-matrix parameters for the model in ``cfg``.

    The count is made of:
      * a ``vocab_size * d_model`` token-embedding matrix,
      * per layer, four ``d_model * d_model`` attention projections
        (Q, K, V, output) and three ``d_model * d_ff`` feed-forward matrices,
      * a ``vocab_size * d_model`` LM head.

    Only these terms are counted; nothing else is added by this function.
    """
    width = cfg.d_model

    # The embedding table and the LM head have the same element count.
    vocab_matrix = cfg.vocab_size * width

    feed_forward = 3 * width * cfg.d_ff
    attention_proj = 4 * width * width
    per_layer = feed_forward + attention_proj

    return vocab_matrix + cfg.num_layers * per_layer + vocab_matrix


def count_transformer_lm_activations_per_batch_element(cfg: TransformerLMConfig) -> int:
    """Return the number of activation values counted for one batch element.

    Per layer the tally is: two RMSNorm outputs, the attention sub-block
    (Q/K/V projections, the QK score matrix, its softmax, the weighted sum of
    values, the output projection) and the feed-forward sub-block (W1 matmul,
    SiLU, W3 matmul, W2 matmul). One final RMSNorm output and the logits are
    added on top of ``num_layers`` such layers.
    """
    seq_len = cfg.context_length

    # Three element counts recur throughout the tally.
    hidden = seq_len * cfg.d_model           # context_length * d_model
    scores = cfg.num_heads * seq_len**2      # num_heads * context_length^2
    inner = seq_len * cfg.d_ff               # context_length * d_ff

    attention = sum(
        (
            3 * hidden,  # Q, K and V projections
            scores,      # QK scores
            scores,      # softmax over the scores
            hidden,      # weighted sum of values
            hidden,      # output projection
        )
    )

    feed_forward = sum(
        (
            inner,   # W1 matmul
            inner,   # SiLU
            inner,   # W3 matmul
            hidden,  # W2 matmul
        )
    )

    # Each layer contributes two RMSNorm outputs besides the two sub-blocks.
    per_layer = 2 * hidden + attention + feed_forward

    logits = seq_len * cfg.vocab_size

    # `hidden` here is the final RMSNorm output.
    return cfg.num_layers * per_layer + hidden + logits


def count_lm_transformer_adamw_memory_usage(cfg: TransformerLMConfig, bytes_per_float: int = 4) -> Dict[str, int]:
    """Estimate, in bytes, the memory needed to train the model in ``cfg`` with AdamW.

    Args:
        cfg: Model and batch configuration; ``cfg.batch_size`` scales the
            activation memory.
        bytes_per_float: Storage size of one value, applied uniformly to
            parameters, gradients, optimizer state and activations.

    Returns:
        A dict of byte counts with the keys ``params``, ``activations``
        (for the whole batch), ``gradients``, ``optimizer_state``,
        ``activations_per_batch_element``, ``non_activations`` and ``total``.
        ``total`` is ``non_activations`` plus the whole-batch ``activations``.
    """
    params = count_transformer_lm_params(cfg) * bytes_per_float
    activations_per_batch_element = count_transformer_lm_activations_per_batch_element(cfg) * bytes_per_float
    # Activation memory grows linearly with the number of batch elements.
    activations = activations_per_batch_element * cfg.batch_size

    # One gradient value per parameter.
    gradients = params
    # Two optimizer values per parameter (AdamW's first and second moments).
    optimizer_state = params * 2

    non_activations = params + gradients + optimizer_state

    return {
        "params": params,
        "activations": activations,
        "gradients": gradients,
        "optimizer_state": optimizer_state,
        "activations_per_batch_element": activations_per_batch_element,
        "non_activations": non_activations,
        # Use the batch-scaled activations so the total matches the
        # "activations" entry for any batch size, not only batch_size == 1.
        "total": non_activations + activations,
    }


# Estimate memory for GPT-2 XL with the batch size forced to 1: copy the
# stored config's fields, override batch_size, and build a fresh config.
memory_bytes = count_lm_transformer_adamw_memory_usage(
    TransformerLMConfig(**dict(vars(model_configs["gpt-2-xl"]), batch_size=1))
)

# Report every entry divided by 2**30 bytes (the output labels this "GB").
for key, value in memory_bytes.items():
    print(f"{key}: {value / 2**30:.2f} GB")

params: 7.92 GB
activations: 15.43 GB
gradients: 7.92 GB
optimizer_state: 15.85 GB
activations_per_batch_element: 15.43 GB
non_activations: 31.69 GB
total: 47.13 GB
