In [128]:
import torch

# Count the number of params/flops in GPT

## I. Parameters

### I.1 count from model architecture

In [140]:
from collections import OrderedDict

In [None]:
# Size presets (layers / heads / embedding width); the subscript at the end
# picks the preset that the rest of the notebook analyses.
config_args = {
    'gpt2':        {'n_layer': 12, 'n_head': 12, 'n_embd': 768},   # 124M params
    'gpt2-medium': {'n_layer': 24, 'n_head': 16, 'n_embd': 1024},  # 350M params
    'gpt2-large':  {'n_layer': 36, 'n_head': 20, 'n_embd': 1280},  # 774M params
    'gpt2-xl':     {'n_layer': 48, 'n_head': 25, 'n_embd': 1600},  # 1558M params
}['gpt2']

class GPTConfigs():
    """Hyper-parameters read by the parameter / FLOPs counters.

    Args:
        config_args: mapping that provides 'n_layer', 'n_head' and 'n_embd'.
        seq_length: context length in tokens (default 1024).
        vocab_size: vocabulary size (default 50257).
        bias: whether layers carry bias terms. The counting code assumes
            they do not, so only a falsy value is accepted.

    Raises:
        KeyError: if `config_args` lacks one of the required keys.
        AssertionError: if `bias` is truthy.
    """
    def __init__(self, config_args, seq_length=1024, vocab_size=50257, bias=False):
        self.n_layer = config_args['n_layer']
        self.n_head = config_args['n_head']
        self.n_embd = config_args['n_embd']
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.bias = bias
        assert not self.bias, 'Assumes False bias for simplicity.'

# Config object for the preset selected in `config_args`; it is the input to
# both the Params and the Flops counters further down.
nanoGPT_config = GPTConfigs(config_args)

In [145]:
class Params():
    """Closed-form parameter counter for a GPT-style model.

    `config_args` is a config object exposing the attributes `vocab_size`,
    `n_embd`, `seq_length` and `n_layer`.
    """

    def __init__(self, config_args):
        self.cfg = config_args

    def params_count(self):
        """Return an OrderedDict mapping component name -> parameter count.

        Per-block entries ('att_*', 'mlp_*', 'att', 'mlp', 'block') describe a
        single transformer block; 'total_blocks' multiplies by `n_layer`.
        Biases are not counted. 'total_minus_embed' is the total with only
        the position embeddings removed (token embeddings stay in).
        """
        cfg = self.cfg
        n_vocab, width, ctx_len, depth = (
            cfg.vocab_size, cfg.n_embd, cfg.seq_length, cfg.n_layer
        )
        hidden = 4 * width  # inner width of the MLP
        counts = OrderedDict()

        # --- embeddings: token table + position table
        counts['embed_token'] = n_vocab * width
        counts['embed_posi'] = ctx_len * width
        counts['embed'] = counts['embed_token'] + counts['embed_posi']

        # --- one attention sub-block
        counts['att_ln'] = width  # LayerNorm gain only (no bias)
        counts['att_kqv'] = width * width * 3
        counts['att_proj'] = width * width
        counts['att'] = sum(counts[key] for key in ('att_ln', 'att_kqv', 'att_proj'))

        # --- one MLP sub-block
        counts['mlp_ln'] = width
        counts['mlp_ff'] = width * hidden
        counts['mlp_proj'] = hidden * width
        counts['mlp'] = sum(counts[key] for key in ('mlp_ln', 'mlp_ff', 'mlp_proj'))

        # --- the full stack of blocks
        counts['block'] = counts['att'] + counts['mlp']
        counts['total_blocks'] = depth * counts['block']

        # --- tail of the network: final LayerNorm and output linear layer
        counts['ln_final'] = width
        counts['linear_final'] = 0  # shares weights with the token embedding, so not counted again

        # --- grand totals
        counts['total'] = counts['embed'] + counts['total_blocks'] + counts['ln_final']
        counts['total_minus_embed'] = counts['total'] - counts['embed_posi']

        return counts


In [143]:
# Run the architecture-based count once and keep the two totals used later on.
p = Params(nanoGPT_config).params_count()
params_total, params_no_embd = p['total'], p['total_minus_embed']
# Compare against the hard-coded reference count for this preset.
print(f"we see: {params_total}, expected: {124337664}, match: {params_total == 124337664}")


we see: 124337664, expected: 124337664, match: True


In [144]:
# One row per component: name, count in millions, share of the total in percent.
for comp, n_params in p.items():
    print(f"{comp:20s} {n_params/1e6:10.1f}M {n_params/params_total*100:10.1f}%")

embed_token                38.6M       31.0%
embed_posi                  0.8M        0.6%
embed                      39.4M       31.7%
att_ln                      0.0M        0.0%
att_kqv                     1.8M        1.4%
att_proj                    0.6M        0.5%
att                         2.4M        1.9%
mlp_ln                      0.0M        0.0%
mlp_ff                      2.4M        1.9%
mlp_proj                    2.4M        1.9%
mlp                         4.7M        3.8%
block                       7.1M        5.7%
total_blocks               85.0M       68.3%
ln_final                    0.0M        0.0%
linear_final                0.0M        0.0%
total                     124.3M      100.0%
total_minus_embed         123.6M       99.4%


### I.2 count from checkpoint

In [134]:
# weights are stored in fp32, i.e. 4 bytes per parameter
params_bytes = 4 * params_total
# AdamW keeps 2 extra statistics buffers per parameter, so the checkpoint
# holds three parameter-sized copies in total (weights + 2 buffers)
param_opt_bytes = 3 * params_bytes

print(f'est checkpoint size: {param_opt_bytes}bytes = {param_opt_bytes/1e9:.2f}GB')

est checkpoint size: 1492051968bytes = 1.49GB


**in the buffers of AdamW**
- the AdamW optimizer has 2 additional buffers per param for statistics
- in the checkpoint of nanoGPT file, the params can be referred to as: 
  ```python
  ckpt = torch.load('../log/model_00050.pt')
  states = ckpt['optimizer']['state']  # all params are saved here
  param0 = states[0]    # states of params can be referred by their order in the model
  param40 = states[40]  # each param has two states for 1st and 2nd moment est
  ```
- each parameter has its own key (a number) in `states`, and each entry of `states` holds two buffers shaped like the parameter (the 1st and 2nd moment estimates)

In [135]:
# word count the checkpoint file
measured = !wc -c ../log/model_00050.pt 
measured_bytes = int((measured[0].split(' '))[0])
print(f"ckpt.pt size: {measured_bytes}bytes = {measured_bytes/1e9:.2f}GB")

ckpt.pt size: 1493898466bytes = 1.49GB


In [136]:
# word count of a checkpoint file with an empty optimizer 
measured_no_opt = !wc -c ../log/model_00000.pt 
measured_no_opt_bytes = int((measured_no_opt[0].split(' '))[0])
print(f"ckpt.pt size: {measured_no_opt_bytes}bytes = {measured_no_opt_bytes/1e9:.2f}GB")

ckpt.pt size: 497963882bytes = 0.50GB


### I.3 GPU memory occupy

In [139]:
mem_4090 = 24e9  # memory budget of the card, in bytes
# share of that budget taken by the weights plus the two AdamW buffers
param_mem_pct = param_opt_bytes / mem_4090 * 100
print(f"memory taken up for parameters: {param_mem_pct:.2f}%")

memory taken up for parameters: 6.22%


## II. Flops

### II.1 count from model architecture

- only count Weight(Matrix Multiply) FLOPs, all others (LayerNorm, Softmax, etc) are effectively irrelevant.
  - because the computational complexity of matrix multiplication is O(n²m) for an n×m matrix, which scales much faster than operations like LayerNorm O(n) or Softmax O(n)
  - and Weight operations involve large matrix multiplications that can be highly optimized using parallel processing and specialized matrix multiplication units. but other operations often require more sequential processing and memory access
  - For most practical purposes, weight FLOPs give a good approximation of relative computational cost between models

In [None]:
class Flops():
    """Matrix-multiply FLOPs counter for a GPT-style model.

    `config_args` is a config object exposing the attributes `vocab_size`,
    `n_embd`, `seq_length`, `n_layer` and `n_head`.
    """

    def __init__(self, config_args):
        self.cfg = config_args

    def flops_count(self):
        """Return an OrderedDict of FLOPs needed to process ONE sequence.

        Counts real FLOPs rather than MACs (multiply-accumulates), which is
        why every matmul carries a factor of 2:
        X (BxC) @ Y (CxD) -> Z (BxD) costs ~2*B*C*D.
        Per-block entries describe a single transformer block;
        'total_blocks' multiplies by `n_layer`. The backward pass is taken
        to cost twice the forward pass.
        """
        cfg = self.cfg
        n_vocab, width, ctx_len = cfg.vocab_size, cfg.n_embd, cfg.seq_length
        depth, heads = cfg.n_layer, cfg.n_head
        per_head = width // heads
        hidden = 4 * width  # inner width of the MLP

        def matmul(rows, inner, cols):
            # FLOPs of a (rows x inner) @ (inner x cols) product
            return 2 * rows * inner * cols

        flops = OrderedDict()

        ## attention sub-block
        #  1) projection to K, Q, V
        flops['att_kqv'] = matmul(ctx_len, width, 3 * width)
        #  2) QK scores
        flops['att_qkscores'] = matmul(ctx_len, width, ctx_len)
        #  3) weighted average, per head: (T, T) @ (T, hs) -> (T, hs)
        flops['att_alpha'] = heads * matmul(ctx_len, ctx_len, per_head)
        #  4) output projection
        flops['att_proj'] = matmul(ctx_len, width, width)
        flops['att'] = (
            flops['att_kqv'] + flops['att_qkscores']
            + flops['att_alpha'] + flops['att_proj']
        )

        ## MLP sub-block
        flops['mlp_ff'] = matmul(ctx_len, width, hidden)
        flops['mlp_proj'] = matmul(ctx_len, hidden, width)
        flops['mlp'] = flops['mlp_ff'] + flops['mlp_proj']

        ## the full stack of blocks
        flops['block'] = flops['att'] + flops['mlp']
        flops['total_blocks'] = depth * flops['block']

        ## final linear layer onto the vocabulary
        flops['linear_final'] = matmul(ctx_len, width, n_vocab)

        ## totals: forward, backward (2x forward) and their sum
        forward = flops['total_blocks'] + flops['linear_final']
        flops['forward_total'] = forward
        flops['backward_total'] = 2 * forward
        flops['total'] = flops['forward_total'] + flops['backward_total']

        return flops



In [165]:
# Count the FLOPs for the selected config; ratios below are relative to the forward pass.
f = Flops(nanoGPT_config).flops_count()
flops_forward = f['forward_total']
print(f"{'name':20s} {'flops':14s} {'ratio (%)':10s}")
# One row per component: name, FLOPs in millions, share of the forward pass.
for comp, n_flops in f.items():
    print(f"{comp:20s} {n_flops/1e6:.0f} M {n_flops/flops_forward*100:10.1f}%")

name                 flops          ratio (%) 
att_kqv              3624 M        1.2%
att_qkscores         1611 M        0.6%
att_alpha            1611 M        0.6%
att_proj             1208 M        0.4%
att                  8053 M        2.8%
mlp_ff               4832 M        1.7%
mlp_proj             4832 M        1.7%
mlp                  9664 M        3.3%
block                17717 M        6.1%
total_blocks         212601 M       72.9%
linear_final         79047 M       27.1%
forward_total        291648 M      100.0%
backward_total       583297 M      200.0%
total                874945 M      300.0%


### II.2 model flops utilization (MFU)

- 4090:
  - 4090_bf16_flops_promised = 165T   # with FP32 accumulate
  - 4090_tf32_flops_promised = 82.6T
  - 4090_fp32_flops_promised = 82.6T  # pytorch default
- 3090:
  - 3090_bf16_flops_promised = 71T    # with FP32 accumulate
  - 3090_tf32_flops_promised = 35.6T
  - 3090_fp32_flops_promised = 35.6T  # pytorch default
- A100:
  - A100_bf16_flops_promised = 312T   # with FP32 accumulate
  - A100_tf32_flops_promised = 156T
  - A100_fp32_flops_promised = 19.5T  # pytorch default

In [None]:
# Peak throughput per GPU, in FLOPs per second.
# bf16 (with fp32 accumulate)
gpu_bf16_cap = {
    '3090': 71e12,
    '4090': 165e12,
}
# tf32 — the 4090 is the faster card (82.6T vs 35.6T for the 3090);
# the two values were previously swapped.
gpu_tf32_cap = {
    '3090': 35.6e12,
    '4090': 82.6e12,
}

In [182]:
total_bsize = 589824  # grad_accum_steps * (B * T): measured number of tokens in a macro batch
B = 24
T = 1024
# number of sequences in one macro batch; each sequence holds T tokens
total_seq = total_bsize / T

dt = 4700  # measured time for one macro batch, in ms
dt_per_seq = dt / total_seq  # also in ms
print(f'each sequence need {dt_per_seq:.2f} milliseconds')

each sequence need 8.16 micro seconds


In [185]:
# forward + backward FLOPs for one sequence
flops_per_seq = f['total']
# dt_per_seq is in ms: first FLOPs per ms, then scale by 1000 to get FLOPs per second
flops_per_ms = flops_per_seq / dt_per_seq
flops_per_second = flops_per_ms * 1000
print(f'about {flops_per_second/1e12:.2f} Tflops is done in one second')

about 107.23 Tflops is done in one second


In [188]:
# MFU = achieved FLOPs per second / peak bf16 FLOPs per second listed for the 4090
gpu_cap = gpu_bf16_cap['4090']
flops_achieved = flops_per_second
fraction_of_gpu_used = flops_achieved / gpu_cap
print(f'The MFU (model flops utilization) of nanoGPT is: {fraction_of_gpu_used*100:.1f}%')


The MFU (model flops utilization) of nanoGPT is: 65.0%


### II.3 total cost(flops) of training

In [197]:
def model_time_cost(model_size, token_num, gpu, gpu_num, assumed_mfu):
    """Estimate the training time, in seconds.

    Args:
        model_size: number of model parameters.
        token_num: number of training tokens.
        gpu: key into `gpu_bf16_cap` selecting the GPU model.
        gpu_num: number of GPUs.
        assumed_mfu: assumed model FLOPs utilization, as a fraction.

    Returns:
        Required FLOPs divided by the sustained FLOPs per second.
    """
    # ~6 FLOPs per parameter per token (estimate from the scaling law paper)
    total_flops = 6 * model_size * token_num

    # sustained rate = peak bf16 rate of one GPU * utilization * number of GPUs
    sustained_flops_per_s = gpu_bf16_cap[gpu] * assumed_mfu * gpu_num

    return total_flops / sustained_flops_per_s


In [198]:
# Scenario: this model, 300B tokens, a single 4090 running at 50% MFU.
model_size = params_no_embd
token_num = 300e9 # suppose 300B tokens
gpu, gpu_num = '4090', 1
assumed_mfu = 0.5

seconds_cost = model_time_cost(
    model_size=model_size,
    token_num=token_num,
    gpu=gpu,
    gpu_num=gpu_num,
    assumed_mfu=assumed_mfu,
)
print(f'time needed to train the model: {seconds_cost/3600/24:.1f} days.')

time needed to train the model: 31.2 days.
