In [1]:
general_config = {
    "vocab_size": 50_257,
    "context_length": 1_024,
}

model_configs = {
    "gpt-2-small": {
        "num_layers": 12,
        "d_model": 768,
        "num_heads": 12,
        "d_ff": 4 * 768,
    },
    "gpt-2-medium": {
        "num_layers": 24,
        "d_model": 1_024,
        "num_heads": 16,
        "d_ff": 4 * 1024,
    },
    "gpt-2-large": {
        "num_layers": 36,
        "d_model": 1_280,
        "num_heads": 20,
        "d_ff": 4 * 1280,
    },
    "gpt-2-xl": {
        "num_layers": 48,
        "d_model": 1_600,
        "num_heads": 25,
        "d_ff": 4 * 1600,
    },
}

In [8]:
def matmul_flops(m: int, n: int, p: int) -> int:
    return 2 * m * n * p


def calc_transformer_lm_flops(
    vocab_size: int,
    context_length: int,
    num_layers: int,
    d_model: int,
    num_heads: int,
    d_ff: int,
):
    rope_group_size = 2

    d_head = d_model // num_heads

    glu_flops = 2 * matmul_flops(context_length, d_model, d_ff) + matmul_flops(context_length, d_ff, d_model)

    qkv_proj_flops = 3 * matmul_flops(context_length, d_model, d_model)
    rope_flops = 2 * matmul_flops(
        context_length * num_heads, rope_group_size, rope_group_size
    )  # 2 * for queries and keys
    attn_flops = matmul_flops(num_heads * context_length, d_head, context_length) + matmul_flops(
        num_heads * context_length, context_length, d_head
    )
    output_proj_flops = matmul_flops(context_length, d_model, d_model)
    mha_flops = qkv_proj_flops + rope_flops + attn_flops + output_proj_flops

    layer_flops = mha_flops + glu_flops

    lm_head_flops = matmul_flops(context_length, d_model, vocab_size)

    total_flops = num_layers * layer_flops + lm_head_flops

    def print_teraflops(**kwargs):
        TERA = 10**12
        for name, flops in kwargs.items():
            print(f"- {name}={flops / TERA:.2f} ({flops / total_flops:.2%})")

    print_teraflops(total_flops=total_flops)

    print()

    print_teraflops(mha_flops=mha_flops * num_layers)
    print_teraflops(qkv_proj_flops=qkv_proj_flops * num_layers)
    print_teraflops(rope_flops=rope_flops * num_layers)
    print_teraflops(attn_flops=attn_flops * num_layers)
    print_teraflops(output_proj_flops=output_proj_flops * num_layers)

    print()

    print_teraflops(glu_flops=glu_flops * num_layers)

    print()

    print_teraflops(lm_head_flops=lm_head_flops)


for model_name, model_config in model_configs.items():
    print(f"*{model_name}*")
    calc_transformer_lm_flops(**general_config, **model_config)
    print("-" * 3)


*gpt-2-small*
- total_flops=0.35 (100.00%)

- mha_flops=0.10 (27.64%)
- qkv_proj_flops=0.04 (12.44%)
- rope_flops=0.00 (0.00%)
- attn_flops=0.04 (11.06%)
- output_proj_flops=0.01 (4.15%)

- glu_flops=0.17 (49.75%)

- lm_head_flops=0.08 (22.61%)
---
*gpt-2-medium*
- total_flops=1.03 (100.00%)

- mha_flops=0.31 (29.93%)
- qkv_proj_flops=0.15 (14.97%)
- rope_flops=0.00 (0.00%)
- attn_flops=0.10 (9.98%)
- output_proj_flops=0.05 (4.99%)

- glu_flops=0.62 (59.87%)

- lm_head_flops=0.11 (10.20%)
---
*gpt-2-large*
- total_flops=2.26 (100.00%)

- mha_flops=0.68 (29.96%)
- qkv_proj_flops=0.36 (16.05%)
- rope_flops=0.00 (0.00%)
- attn_flops=0.19 (8.56%)
- output_proj_flops=0.12 (5.35%)

- glu_flops=1.45 (64.20%)

- lm_head_flops=0.13 (5.84%)
---
*gpt-2-xl*
- total_flops=4.51 (100.00%)

- mha_flops=1.33 (29.44%)
- qkv_proj_flops=0.75 (16.73%)
- rope_flops=0.00 (0.00%)
- attn_flops=0.32 (7.14%)
- output_proj_flops=0.25 (5.58%)

- glu_flops=3.02 (66.91%)

- lm_head_flops=0.16 (3.65%)
---


In [11]:
print("**gpt-2-xl with large context length**")
calc_transformer_lm_flops(**{**general_config, "context_length": 16_384}, **model_configs["gpt-2-xl"])

**gpt-2-xl with large context length**
- total_flops=4.51 (100.00%)

- mha_flops=1.33 (29.44%)
- qkv_proj_flops=0.75 (16.73%)
- rope_flops=0.00 (0.00%)
- attn_flops=0.32 (7.14%)
- output_proj_flops=0.25 (5.58%)

- glu_flops=3.02 (66.91%)

- lm_head_flops=0.16 (3.65%)


In [4]:
def count_transformer_lm_params(vocab_size: int, num_layers: int, d_model: int, num_heads: int, d_ff: int, **kwargs):
    token_embeddings_params = vocab_size * d_model
    print(d_model, d_ff)
    glu_params = 3 * d_model * d_ff
    qkvo_proj_params = 4 * d_model * d_model
    lm_head_params = vocab_size * d_model

    total_params = token_embeddings_params + num_layers * (glu_params + qkvo_proj_params) + lm_head_params

    def print_params(**kwargs):
        BILLION = 2**30
        for name, params in kwargs.items():
            print(f"{name}={params / BILLION:.2f}B ({params / total_params:.2%})")

    print_params(total_params=total_params)
    print_params(token_embeddings_params=token_embeddings_params)
    print_params(glu_params=glu_params * num_layers)
    print_params(qkvo_proj_params=qkvo_proj_params * num_layers)
    print_params(lm_head_params=lm_head_params)


for model_name, model_config in model_configs.items():
    print(model_name)
    print("-" * 20)
    count_transformer_lm_params(**general_config, **model_config)
    print("-" * 100)


gpt-2-small
--------------------
768 3072
total_params=0.18B (100.00%)
token_embeddings_params=0.04B (20.27%)
glu_params=0.08B (44.60%)
qkvo_proj_params=0.03B (14.87%)
lm_head_params=0.04B (20.27%)
----------------------------------------------------------------------------------------------------
gpt-2-medium
--------------------
1024 4096
total_params=0.47B (100.00%)
token_embeddings_params=0.05B (10.18%)
glu_params=0.28B (59.73%)
qkvo_proj_params=0.09B (19.91%)
lm_head_params=0.05B (10.18%)
----------------------------------------------------------------------------------------------------
gpt-2-large
--------------------
1280 5120
total_params=1.00B (100.00%)
token_embeddings_params=0.06B (6.00%)
glu_params=0.66B (66.00%)
qkvo_proj_params=0.22B (22.00%)
lm_head_params=0.06B (6.00%)
----------------------------------------------------------------------------------------------------
gpt-2-xl
--------------------
1600 6400
total_params=1.98B (100.00%)
token_embeddings_params=0.07B (3.