In [1]:
import numpy as np

In [63]:
# Model hyperparameters shared by the FLOP count and the parameter count.
vocab_size = 50257  # vocabulary size: embedding rows and output-projection width
num_layers = 48     # number of Transformer blocks
d_model = 1600      # model (hidden) dimension
seq_len = 1024      # tokens per sequence
num_heads = 25      # attention heads per block

# The counts use num_heads * d_k as the attention projection width, which only
# equals d_model when the split is exact; fail fast instead of letting the
# floor division silently drop dimensions.
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
d_k = d_model // num_heads  # per-head dimension (64 with the values above)

# Feed-forward inner dimension. Normally d_model * 8/3 (about 4267 here);
# 6400 is 4 * d_model.
d_ff = 6400

In [64]:
# Forward-pass FLOP estimate for a single sequence of seq_len tokens.
# Only matrix multiplies are counted, each (m x k) @ (k x n) product as
# 2*m*k*n FLOPs; the embedding lookup, the RMSNorms and the softmax are
# counted as 0. No batch factor appears in any formula.
# Every Transformer block has identical shapes, so one block is costed once
# and scaled by num_layers instead of re-adding the same terms in a loop.

###### START Transformer block (cost of ONE layer) ######

# RMSNorm 1: counted as 0.

# MHA with RoPE
# 3 matrix multiplies for Q, K, V:
#   "... batch seq d_model, h_d_k d_model -> ... batch seq h_d_k"
qkv_proj_flops = 3 * (2 * seq_len * d_model * (num_heads * d_k))

# RoPE for Q and K, per head, costed as a dense (d_k x d_k) product per position:
#   "... seq d_k1 d_k2, ... seq d_k2 -> ... seq d_k1"
rope_flops_per_layer = num_heads * (2 * (2 * d_k * d_k * seq_len))

# scaled_dot_product_attention 1, per head: "... seq d_k, ... seq d_k -> ... seq seq"
attn_scores_flops = num_heads * (2 * seq_len * d_k * seq_len)

# scaled_dot_product_attention 2, per head: "... seq seq, ... seq d_k -> ... seq d_k"
attn_values_flops = num_heads * (2 * seq_len * seq_len * d_k)

# MHA output projection: "... batch seq h_d_k, d_model h_d_k -> ... batch seq d_model"
out_proj_flops = 2 * seq_len * (num_heads * d_k) * d_model

mha_rope_flops_per_layer = (
    qkv_proj_flops
    + rope_flops_per_layer
    + attn_scores_flops
    + attn_values_flops
    + out_proj_flops
)

# RMSNorm 2: counted as 0.

# PositionwiseFeedforward (3 matrix multiplies):
#   "batch seq d_model, d_ff d_model -> batch seq d_ff"
#   "batch seq d_model, d_ff d_model -> batch seq d_ff"
#   "batch seq d_ff, d_model d_ff -> batch seq d_model"
pw_feed_forward_flops_per_layer = 3 * (2 * seq_len * d_model * d_ff)

###### END Transformer block ######

# Totals over all layers.
rope_flops = num_layers * rope_flops_per_layer
mha_rope_flops = num_layers * mha_rope_flops_per_layer
pw_feed_forward_flops = num_layers * pw_feed_forward_flops_per_layer
# Despite the name, this is the cost of the full Transformer blocks (MHA + feed-forward).
attn_block_flops = mha_rope_flops + pw_feed_forward_flops

# Last RMSNorm "batch seq d_model -> batch seq d_model": counted as 0.

# Linear projection: "batch seq d_model, vocab_size d_model -> batch seq vocab_size"
output_embedding_flops = 2 * seq_len * d_model * vocab_size

# Softmax: counted as 0.

num_matmul_flops = attn_block_flops + output_embedding_flops

# Report as LaTeX lines. The first two percentages are shares of the total;
# the last three are shares of the Transformer-block cost (attn_block_flops).
print(f'\\noindent Total FLOPs: {num_matmul_flops:e} \\\\')
print()
print(f'\\noindent Proportion of FLOPs in Attention Blocks: {round(100*attn_block_flops/num_matmul_flops, 2)}\\% \\\\')
print(f'\\noindent Proportion of FLOPs for Output Embedding: {round(100*output_embedding_flops/num_matmul_flops, 2)}\\% \\\\')
print()
print(f'\\noindent Proportion of FLOPs for RoPE: {round(100*rope_flops/attn_block_flops, 2)}\\% \\\\')
print(f'\\noindent Proportion of FLOPs for Multi-Head Attention with RoPE: {round(100*mha_rope_flops/attn_block_flops, 2)}\\% \\\\')
print(f'\\noindent Proportion of FLOPs for Position-Wise Feed-Forward: {round(100*pw_feed_forward_flops/attn_block_flops, 2)}\\% \\\\')


\noindent Total FLOPs: 4.533469e+12 \\

\noindent Proportion of FLOPs in Attention Blocks: 96.37\% \\
\noindent Proportion of FLOPs for Output Embedding: 3.63\% \\

\noindent Proportion of FLOPs for RoPE: 0.46\% \\
\noindent Proportion of FLOPs for Multi-Head Attention with RoPE: 30.88\% \\
\noindent Proportion of FLOPs for Position-Wise Feed-Forward: 69.12\% \\


In [65]:
# Total forward-pass matmul FLOPs expressed in units of 1e14.
# The literal is the num_matmul_flops value that the FLOP-count formulas give
# for seq_len = 16384 with every other hyperparameter unchanged; it is
# hand-copied, so it goes stale if the configuration changes.
149_844_918_272_000 / 1e14

1.49844918272

## Count parameters

In [68]:
# Trainable-parameter count for the model, kept in exact integer arithmetic.
# Each Transformer block has the same shapes, so one block is counted once
# and multiplied by num_layers.

# Embedding
embedding_params = vocab_size * d_model

###### START Transformer block (parameters of ONE layer) ######

# RMSNorm 1 and RMSNorm 2: d_model parameters each
rmsnorm_params = d_model

# MHA with RoPE
# 3 weight matrices for Q, K, V:
#   "... batch seq d_model, h_d_k d_model -> ... batch seq h_d_k"
# num_heads * d_k is used (rather than d_model / num_heads) so the count stays
# an int and matches the projection width used in the FLOP count.
qkv_proj_params = 3 * (num_heads * d_k) * d_model

# RoPE for Q and K: no trainable parameters.

# MHA output projection: "... batch seq h_d_k, d_model h_d_k -> ... batch seq d_model"
out_proj_params = (num_heads * d_k) * d_model

# PositionwiseFeedforward (3 weight matrices):
#   "batch seq d_model, d_ff d_model -> batch seq d_ff"
#   "batch seq d_model, d_ff d_model -> batch seq d_ff"
#   "batch seq d_ff, d_model d_ff -> batch seq d_model"
pw_feed_forward_params = 3 * d_model * d_ff

block_params = (
    2 * rmsnorm_params
    + qkv_proj_params
    + out_proj_params
    + pw_feed_forward_params
)

###### END Transformer block ######

# Last RMSNorm "batch seq d_model -> batch seq d_model"
final_norm_params = d_model

# Linear projection: "batch seq d_model, vocab_size d_model -> batch seq vocab_size"
# Counted separately from the input embedding (both tables are added to the total).
output_proj_params = vocab_size * d_model

# Softmax: no trainable parameters.

num_parameters = (
    embedding_params
    + num_layers * block_params
    + final_norm_params
    + output_proj_params
)
print(f'\\noindent Total trainable parameters: {num_parameters:e} \\\\')

\noindent Total trainable parameters: 2.127058e+09 \\


In [70]:
# Size of the parameters in GiB (2**30 bytes) at 4 bytes per parameter.
# NOTE(review): 4 bytes presumably corresponds to float32 weights — confirm.
4 * num_parameters / 2**30

7.923907041549683