# Configurations, Constants and Enums

In [27]:
import os
from dataclasses import dataclass
from enum import Enum, IntEnum
from pprint import pprint
from typing import Dict, Literal


@dataclass
class GPTConfig:
    """Hyperparameters of a GPT-style decoder-only transformer.

    Construction fails with ``AssertionError`` unless ``ffw_size`` equals
    ``4 * n_embd`` and ``bias`` is exactly ``False``.
    """

    num_decoder_blocks: int = 12  # number of stacked decoder blocks
    context_length: int = 1024  # sequence length
    n_embd: int = 768  # embedding width
    ffw_size: int = 3072  # feed-forward hidden width; validated to be 4 * n_embd
    n_head: int = 12  # attention heads per block
    vocab_size: int = 50257  # number of entries in the token embedding table
    bias: Literal[False] = False  # validated to be False

    def __post_init__(self) -> None:
        """Check the two invariants the rest of the notebook relies on."""
        required_ffw_size = 4 * self.n_embd
        assert self.ffw_size == required_ffw_size, "ffw_size must be 4 * n_embd"
        assert self.bias is False, "bias must be False in this experiment."


class GPT2ModelType(Enum):
    """Names of the GPT-2 model variants; each value is the variant's string identifier."""

    GPT2 = "gpt2"
    GPT2_MEDIUM = "gpt2-medium"
    GPT2_LARGE = "gpt2-large"
    GPT2_XL = "gpt2-xl"


class ByteUnits(IntEnum):
    """Decimal (power-of-1000) byte units; each value is the unit's size in bytes."""

    B = 10 ** 0  # byte
    KB = 10 ** 3  # kilobyte
    MB = 10 ** 6  # megabyte
    GB = 10 ** 9  # gigabyte


class FloatingPointPrecision(IntEnum):
    """Bytes needed to store one value at each floating-point precision.

    NOTE: FP16 and BFLOAT16 both have the value 2. Python enums treat a
    repeated value as an alias, so ``BFLOAT16`` is the same member object as
    ``FP16``: iteration yields only FP32 and FP16, and a dict keyed by both
    names holds a single entry for them.
    """

    FP32 = 32 // 8  # 32 bits -> 4 bytes
    FP16 = 16 // 8  # 16 bits -> 2 bytes
    BFLOAT16 = 16 // 8  # bfloat16, 16 bits -> 2 bytes (alias of FP16)


class GPUMemory(Enum):
    """Memory capacity of selected NVIDIA GPUs, in bytes (decimal GB, stored as floats).

    NOTE: members with equal values are enum aliases. ``T4_16GB`` and
    ``P100_16GB`` therefore resolve to the ``V100_16GB`` member, and their
    ``.name`` reports ``"V100_16GB"``. The ``.value`` used for arithmetic is
    unaffected.
    """

    A100_40GB = 40 * 1e9  # NVIDIA A100, 40 GB
    V100_16GB = 16 * 1e9  # NVIDIA V100, 16 GB
    V100_32GB = 32 * 1e9  # NVIDIA V100, 32 GB
    T4_16GB = 16 * 1e9  # NVIDIA T4, 16 GB (alias of V100_16GB)
    P100_16GB = 16 * 1e9  # NVIDIA P100, 16 GB (alias of V100_16GB)
    RTX4090_24GB = 24 * 1e9  # NVIDIA RTX 4090, 24 GB


class GPU:
    """A GPU described by a display name and a per-precision throughput table."""

    def __init__(self, name: str, flops: Dict[FloatingPointPrecision, float]) -> None:
        """Store the name and the precision -> throughput mapping as given (no copy is made)."""
        self.name, self.flops = name, flops


class A100(GPU):
    """A100 entry: name "A100" plus its throughput figure for each precision."""

    def __init__(self) -> None:
        # NOTE(review): FP16 and BFLOAT16 are the same enum member (both have
        # value 2 in FloatingPointPrecision), so the two entries below collapse
        # into one dict key. Harmless here because both carry the same figure.
        throughput = {
            FloatingPointPrecision.FP32: 19.5e12,
            FloatingPointPrecision.FP16: 312e12,
            FloatingPointPrecision.BFLOAT16: 312e12,
        }
        super().__init__(name="A100", flops=throughput)


class RTX4090(GPU):
    """RTX 4090 entry: name "RTX 4090" plus its throughput figure for each precision."""

    def __init__(self) -> None:
        # NOTE(review): FP16 and BFLOAT16 are the same enum member (both have
        # value 2 in FloatingPointPrecision), so the two entries below collapse
        # into one dict key. Harmless here because both carry the same figure.
        throughput = {
            FloatingPointPrecision.FP32: 82.6e12,
            FloatingPointPrecision.FP16: 165.2e12,
            FloatingPointPrecision.BFLOAT16: 165.2e12,
        }
        super().__init__(name="RTX 4090", flops=throughput)

In [3]:
# Build a configuration from GPTConfig's field defaults and pretty-print it.
gpt2_config = GPTConfig()
pprint(gpt2_config)

GPTConfig(num_decoder_blocks=12,
          context_length=1024,
          n_embd=768,
          ffw_size=3072,
          n_head=12,
          vocab_size=50257,
          bias=False)


# Total Trainable Parameters

In [14]:
import torch
from transformers import GPT2LMHeadModel
from collections import OrderedDict
import pandas as pd
from tabulate import tabulate

In [8]:
def total_trainable_parameters(model: torch.nn.Module, include_bias: bool = True) -> int:
    """Count the elements of every parameter of ``model`` that requires gradients.

    Parameters
    ----------
    model : torch.nn.Module
        Module whose parameters are counted.
    include_bias : bool
        When False, parameters whose qualified name contains the substring
        ``"bias"`` are left out of the count.

    Returns
    -------
    int
        Total number of scalar elements across the selected parameters.
    """
    element_count = 0
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are never counted
        if not include_bias and "bias" in param_name:
            continue  # name-based bias filter
        element_count += param.numel()
    return element_count

In [21]:
# Location of the pretrained GPT-2 weights. Set the GPT2_MODEL_PATH environment
# variable to load from somewhere else; the original path is kept as the
# fallback so existing runs behave exactly as before.
model_path = os.environ.get("GPT2_MODEL_PATH", '/nfs/home/xiaoxiao/models/hf_models/gpt2')

gpt2 = GPT2LMHeadModel.from_pretrained(model_path)

# Count trainable parameters twice: without and with parameters whose name contains "bias".
gpt2_params_no_bias = total_trainable_parameters(gpt2, include_bias=False)
gpt2_params_with_bias = total_trainable_parameters(gpt2, include_bias=True)

print(
    f"Number of trainable parameters in GPT2 model: {gpt2_params_no_bias} (excluding bias) and {gpt2_params_with_bias} (including bias)."
)

Number of trainable parameters in GPT2 model: 124337664 (excluding bias) and 124439808 (including bias).


In [10]:
def params(
        num_decoder_blocks: int = 12,
        context_length: int = 1024,
        n_embd: int = 768,
        ffw_size: int = 3072,
        vocab_size: int = 50257,
) -> OrderedDict[str, int]:
    """Estimate the parameter count of the model, broken down by component.

    Only weight matrices and LayerNorm scales are counted (no bias terms), and
    the output "dense" layer contributes 0 because it reuses the token
    embedding weights.

    Returns
    -------
    OrderedDict[str, int]
        Component name -> parameter count, ending with ``"total"``. The
        ``"attention/*"``, ``"mlp/*"`` and ``"block"`` entries are per-block
        figures; ``"transformer"`` multiplies ``"block"`` by
        ``num_decoder_blocks``.

    Raises
    ------
    AssertionError
        If ``ffw_size`` is not ``4 * n_embd``.
    """
    assert ffw_size == 4 * n_embd, "ffw_size must be 4 * n_embd"

    # Embedding tables.
    position_table = n_embd * context_length
    token_table = n_embd * vocab_size
    embedding_total = position_table + token_table

    # One attention sub-layer: LayerNorm scale (no bias), fused K/Q/V projection, output projection.
    attn_ln = n_embd
    attn_kqv = n_embd * 3 * n_embd
    attn_proj = n_embd ** 2
    attn_total = attn_ln + attn_kqv + attn_proj

    # One MLP sub-layer: LayerNorm scale, up-projection, down-projection.
    mlp_ln = n_embd
    mlp_up = n_embd * ffw_size
    mlp_down = ffw_size * n_embd
    mlp_total = mlp_ln + mlp_up + mlp_down

    # Whole stack plus the final LayerNorm; the output layer is weight-tied.
    per_block = attn_total + mlp_total
    all_blocks = num_decoder_blocks * per_block
    final_ln = n_embd
    tied_dense = 0

    return OrderedDict([
        ("embedding/position", position_table),
        ("embedding/token", token_table),
        ("embedding", embedding_total),
        ("attention/ln", attn_ln),
        ("attention/kqv", attn_kqv),
        ("attention/proj", attn_proj),
        ("attention", attn_total),
        ("mlp/ln", mlp_ln),
        ("mlp/ffw", mlp_up),
        ("mlp/proj", mlp_down),
        ("mlp", mlp_total),
        ("block", per_block),
        ("transformer", all_blocks),
        ("ln_f", final_ln),
        ("dense", tied_dense),
        ("total", embedding_total + all_blocks + final_ln + tied_dense),
    ])

In [22]:
params_dict = params()
gpt2_params_no_bias_manual = params_dict["total"]

# Compare the hand-derived total with the bias-free count measured on the PyTorch model.
expected_params = gpt2_params_no_bias
comparison_result = gpt2_params_no_bias_manual == expected_params
comparison_msg = f"We see: {gpt2_params_no_bias_manual}, Expected: {expected_params}, Match: {comparison_result}"

# Per-component counts and their share of the manual total (kept numeric in `df`).
data = {
    "Name": list(params_dict.keys()),
    "Parameters": list(params_dict.values()),
    "Ratio (%)": [value / gpt2_params_no_bias_manual * 100 for value in params_dict.values()],
}
df = pd.DataFrame(data)

# tabulate's "pretty" format renders cells as plain text, so `floatfmt` is not
# applied to the ratio column. Round it explicitly on a display-only copy so the
# table shows 4 decimals while `df` keeps the full-precision floats.
df_display = df.assign(**{"Ratio (%)": df["Ratio (%)"].map("{:.4f}".format)})

# Printing comparison result and parameter distribution table
print(comparison_msg + "\n")
print(tabulate(df_display, headers="keys", tablefmt="pretty", showindex=False, numalign="right", floatfmt=".4f"))

We see: 124337664, Expected: 124337664, Match: True

+--------------------+------------+-----------------------+
|        Name        | Parameters |       Ratio (%)       |
+--------------------+------------+-----------------------+
| embedding/position |   786432   |  0.6324970042866496   |
|  embedding/token   |  38597376  |  31.042384711361475   |
|     embedding      |  39383808  |  31.674881715648123   |
|    attention/ln    |    768     | 0.0006176728557486812 |
|   attention/kqv    |  1769472   |  1.4231182596449616   |
|   attention/proj   |   589824   |  0.47437275321498723  |
|     attention      |  2360064   |  1.8981086857156975   |
|       mlp/ln       |    768     | 0.0006176728557486812 |
|      mlp/ffw       |  2359296   |   1.897491012859949   |
|      mlp/proj      |  2359296   |   1.897491012859949   |
|        mlp         |  4719360   |   3.795599698575646   |
|       block        |  7079424   |   5.693708384291344   |
|    transformer     |  84953088  |   68.324500

# Calculating Checkpoint Size and Fluff Ratio

In [23]:
def calculate_checkpoint_size(params_count: int, precision: FloatingPointPrecision, units: ByteUnits) -> float:
    """
    Estimate the size of a checkpoint holding the weights and AdamW state.

    The estimate is three times the raw weight storage: one copy for the
    parameters themselves plus two per-parameter AdamW buffers (first and
    second moment estimates), all counted at the same ``precision``.

    Parameters
    ----------
    params_count : int
        Number of parameters to account for.
    precision : FloatingPointPrecision
        Precision whose ``.value`` is the number of bytes per parameter.
    units : ByteUnits
        Unit in which the result is expressed; its ``.value`` is the unit's
        size in bytes.

    Returns
    -------
    float
        The estimated checkpoint size in ``units``.
    """
    bytes_per_param = precision.value
    weights_bytes = params_count * bytes_per_param
    adamw_state_bytes = 2 * weights_bytes  # two optimizer buffers per parameter
    total_bytes = weights_bytes + adamw_state_bytes
    return total_bytes / units.value


def calculate_fluff_ratio(measured_bytes: int, estimated_bytes: float, units: ByteUnits) -> float:
    """
    Express the measured checkpoint size as a percentage of the estimated size.

    A result of 100 means the file is exactly as large as estimated; anything
    above 100 is overhead ("fluff") beyond the estimate.

    Parameters
    ----------
    measured_bytes : int
        Actual size of the checkpoint file, in bytes.
    estimated_bytes : float
        Estimated checkpoint size, expressed in ``units`` (despite the name).
    units : ByteUnits
        Unit of ``estimated_bytes``; used to convert the estimate to bytes.

    Returns
    -------
    float
        ``measured / estimated * 100``, with both sizes in bytes.
    """
    estimate_in_bytes = estimated_bytes * units.value
    measured_over_estimate = measured_bytes / estimate_in_bytes
    return measured_over_estimate * 100

In [33]:
# Measured size of the checkpoint file on disk.
gpt2_checkpoint_size_measured_in_bytes = 1542470366  # from 'wc -c ckpt.pt'
gpt2_checkpoint_size_measured_in_gb = gpt2_checkpoint_size_measured_in_bytes / ByteUnits.GB

# Estimated size: bias-free parameter count at 4 bytes each, plus the AdamW
# buffers that calculate_checkpoint_size adds. Requested in bytes (units=B).
gpt2_checkpoint_size_estimated_in_bytes = calculate_checkpoint_size(
    params_count=gpt2_params_no_bias,
    precision=FloatingPointPrecision.FP32,
    units=ByteUnits.B,
)
gpt2_checkpoint_size_estimated_in_gb = gpt2_checkpoint_size_estimated_in_bytes / ByteUnits.GB

# Measured size as a percentage of the estimate; both sides are in bytes here.
fluff_ratio = calculate_fluff_ratio(
    measured_bytes=gpt2_checkpoint_size_measured_in_bytes,
    estimated_bytes=gpt2_checkpoint_size_estimated_in_bytes,
    units=ByteUnits.B,
)

# Rows of the summary table: (metric label, value).
data = [
    ["Measured Checkpoint Size (bytes)", gpt2_checkpoint_size_measured_in_bytes],
    ["Measured Checkpoint Size (GB)", gpt2_checkpoint_size_measured_in_gb],
    ["Estimated Checkpoint Size (bytes)", gpt2_checkpoint_size_estimated_in_bytes],
    ["Estimated Checkpoint Size (GB)", gpt2_checkpoint_size_estimated_in_gb],
    ["Fluff Ratio", fluff_ratio],
]

print(tabulate(data, headers=["Metric", "Value"], tablefmt="pretty"))

+-----------------------------------+-------------------+
|              Metric               |       Value       |
+-----------------------------------+-------------------+
| Measured Checkpoint Size (bytes)  |    1542470366     |
|   Measured Checkpoint Size (GB)   |    1.542470366    |
| Estimated Checkpoint Size (bytes) |   1492051968.0    |
|  Estimated Checkpoint Size (GB)   |    1.492051968    |
|            Fluff Ratio            | 103.3791314968461 |
+-----------------------------------+-------------------+


# GPU Memory Footprint of Loading Model and Optimizer

In [28]:
def calculate_memory_ratio(checkpoint_size: float, gpu_memory: GPUMemory) -> str:
    """Report ``checkpoint_size`` as a percentage of a GPU's memory capacity.

    Parameters
    ----------
    checkpoint_size : float
        Size to compare, in the same unit as ``gpu_memory.value`` (bytes).
    gpu_memory : GPUMemory
        GPU whose capacity (``.value``) is the denominator.

    Returns
    -------
    str
        A sentence containing the percentage rounded to two decimals.
    """
    fraction_of_capacity = checkpoint_size / gpu_memory.value
    memory_ratio = fraction_of_capacity * 100
    return f"Memory ratio taken up just for parameters: {memory_ratio:.2f}%"


# Share of a 24 GB RTX 4090 taken by the estimated checkpoint (weights + AdamW buffers), in bytes.
print(
    calculate_memory_ratio(checkpoint_size=gpt2_checkpoint_size_estimated_in_bytes, gpu_memory=GPUMemory.RTX4090_24GB))

Memory ratio taken up just for parameters: 6.22%


# Estimating FLOPs for a Single Forward Pass

In [29]:
def flops(
        num_decoder_blocks: int = 12,
        context_length: int = 1024,
        n_embd: int = 768,
        n_head: int = 12,
        ffw_size: int = 3072,
        vocab_size: int = 50257,
) -> OrderedDict[str, int]:
    """Estimate the FLOPs spent on one sequence of ``context_length`` tokens.

    Only matrix multiplications are counted; LayerNorm, softmax and similar
    layers are ignored. Counts are real FLOPs rather than MACs, so a product
    of an (B x C) matrix with a (C x D) matrix costs ``2 * B * C * D``.

    Parameters
    ----------
    num_decoder_blocks : int
        Number of decoder blocks.
    context_length : int
        Number of tokens in the sequence.
    n_embd : int
        Embedding width.
    n_head : int
        Number of attention heads; the head size is ``n_embd // n_head``.
    ffw_size : int
        Hidden width of the MLP. This value is used as passed; it is no
        longer overwritten with ``4 * n_embd``.
    vocab_size : int
        Vocabulary size, i.e. the output width of the final dense layer.

    Returns
    -------
    OrderedDict[str, int]
        Component name -> FLOPs. The ``"attention/*"``, ``"mlp/*"`` and
        ``"block"`` entries are per-block; ``"forward_total"`` is the whole
        forward pass, ``"backward_total"`` is estimated as twice that, and
        ``"total"`` is their sum.
    """
    out = OrderedDict()
    head_size = n_embd // n_head

    # attention blocks
    # 1) the projection to key, query, values
    out["attention/kqv"] = 2 * context_length * (n_embd * 3 * n_embd)
    # 2) calculating the attention scores
    out["attention/scores"] = 2 * context_length * context_length * n_embd
    # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    out["attention/reduce"] = 2 * n_head * (context_length * context_length * head_size)
    # 4) the final linear projection
    out["attention/proj"] = 2 * context_length * (n_embd * n_embd)
    out["attention"] = sum(out["attention/" + k] for k in ["kqv", "scores", "reduce", "proj"])

    # MLP blocks: up-projection to ffw_size, then down-projection back to n_embd
    out["mlp/ffw1"] = 2 * context_length * (n_embd * ffw_size)
    out["mlp/ffw2"] = 2 * context_length * (ffw_size * n_embd)
    out["mlp"] = out["mlp/ffw1"] + out["mlp/ffw2"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = num_decoder_blocks * out["block"]
    out["dense"] = 2 * context_length * (n_embd * vocab_size)

    # forward, backward, total
    out["forward_total"] = out["transformer"] + out["dense"]
    out["backward_total"] = 2 * out["forward_total"]  # use common estimate of bwd = 2*fwd
    out["total"] = out["forward_total"] + out["backward_total"]

    return out

In [30]:
# FLOPs breakdown for the default configuration; `f` is reused by later cells.
f = flops()
# Ratios below are relative to the forward pass, so backward_total shows 200% and total 300%.
flops_total = f["forward_total"]

# First row is the header (consumed by headers="firstrow"); one row per component follows.
table = [("name", "flops", "ratio (%)")]
for k, v in f.items():
    table.append((k, v, v / flops_total * 100))

print(tabulate(table, headers="firstrow", tablefmt="pretty", numalign="right"))

+------------------+--------------+---------------------+
|       name       |    flops     |      ratio (%)      |
+------------------+--------------+---------------------+
|  attention/kqv   |  3623878656  | 1.2425508965889174  |
| attention/scores |  1610612736  | 0.5522448429284077  |
| attention/reduce |  1610612736  | 0.5522448429284077  |
|  attention/proj  |  1207959552  | 0.41418363219630583 |
|    attention     |  8053063680  | 2.7612242146420387  |
|     mlp/ffw1     |  4831838208  | 1.6567345287852233  |
|     mlp/ffw2     |  4831838208  | 1.6567345287852233  |
|       mlp        |  9663676416  | 3.3134690575704466  |
|      block       | 17716740096  |  6.074693272212485  |
|   transformer    | 212600881152 |  72.89631926654981  |
|      dense       | 79047426048  |  27.10368073345018  |
|  forward_total   | 291648307200 |        100.0        |
|  backward_total  | 583296614400 |        200.0        |
|      total       | 874944921600 |        300.0        |
+-------------

# Model FLOPs Utilization (MFU)

In [31]:
# here is what we currently roughly measure
batch_size = 20 * 5  # 5 is grad_accum, so total batch size is 100
measured_time = 0.755  # in seconds per iteration
measured_throughput = batch_size / measured_time # number of samples processed per second
flops_achieved_per_second = f["total"] * measured_throughput

# A100 is cited to be 312 TFLOPS of bfloat16 running on tensor cores
a100_bfloat16_promised_flops = 312e12

# the fraction of the A100 that we are using:
print(f"fraction of A100 used: {flops_achieved_per_second / a100_bfloat16_promised_flops * 100:.2f}%")

fraction of A100 used: 37.14%


# Theoretical FLOPs in Transformer Models

In [32]:
# Finally let's check out the 6ND approximation as total cost of training in FLOPs
N = params()["total"]  # this is number of parameters, N
D = 300e9  # 300B tokens, this is dataset size in tokens, D
a100_bfloat16_promised_flops = 312e12  # 312 TFLOPS
assumed_mfu = 0.3  # assume this model flops utilization (take the current 37% from above and add some DDP overhead)
flops_throughput = a100_bfloat16_promised_flops * 8 * assumed_mfu  # assume an 8XA100 node at 30% utilization
flops_needed = 6 * N * D  # 6ND
time_needed_over_all_tokens_in_seconds = flops_needed / flops_throughput  # in seconds
print(f"time needed to train the model: {time_needed_over_all_tokens_in_seconds/3600/24:.2f} days")

time needed to train the model: 3.46 days
