# LLM inference Speed of Light measurements

This document estimates the "speed of light" (the theoretical upper bound on performance) for LLM inference, assuming an autoregressive transformer architecture. We only look at the token generation phase and ignore prompt processing.

In [26]:
class Model:
    """Hyper-parameters of a transformer, enough to estimate its weight count.

    Constructor argument names are HF-compatible for consistency.
    ``num_ffn`` is the number of ``hidden_size x intermediate_size`` matrices
    in each feed-forward block; ``tied_embeddings`` means the output
    projection is not counted separately from the embedding table.
    """

    def __init__(
        self,
        name,
        hidden_size,
        intermediate_size,
        num_attention_heads,
        num_key_value_heads,
        num_hidden_layers,
        vocab_size,
        num_ffn=3,
        tied_embeddings=False,
    ):
        self.name = name
        # width / depth of the network
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        # attention layout
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        # vocabulary and structural switches
        self.vocab_size = vocab_size
        self.num_ffn = num_ffn
        self.tied_embeddings = tied_embeddings

    def params(self):
        """Return the weight count; normalization layers and biases are omitted for simplicity."""
        hidden = self.hidden_size
        head_width = hidden // self.num_attention_heads
        kv_width = head_width * self.num_key_value_heads

        per_layer = (
            hidden * hidden * 2  # attention query and output projections
            + hidden * kv_width * 2  # attention key and value projections
            + hidden * self.intermediate_size * self.num_ffn  # feed-forward matrices
        )
        embedding = self.vocab_size * hidden

        total = embedding + self.num_hidden_layers * per_layer
        if not self.tied_embeddings:
            # a separate output projection has the same shape as the embedding table
            total += embedding
        return total

    def __str__(self):
        billions = self.params() / 1e9
        return f'{self.name} {billions:.1f}B'

In [30]:
# a few reference model configurations; the final expression displays their sizes
llama7b = Model(
    'Llama-7b',
    hidden_size=4096,
    intermediate_size=11008,
    num_attention_heads=32,
    num_key_value_heads=32,
    num_hidden_layers=32,
    vocab_size=32000,
)
llama70b = Model(
    'Llama-70b',
    hidden_size=8192,
    intermediate_size=28672,
    num_attention_heads=64,
    num_key_value_heads=8,
    num_hidden_layers=80,
    vocab_size=32000,
)
mistral7b = Model(
    'Mistral-7b',
    hidden_size=4096,
    intermediate_size=14336,
    num_attention_heads=32,
    num_key_value_heads=8,
    num_hidden_layers=32,
    vocab_size=32000,
)
opt7b = Model(
    'opt-7b',
    hidden_size=4096,
    intermediate_size=16384,
    num_attention_heads=32,
    num_key_value_heads=32,
    num_hidden_layers=32,
    vocab_size=50272,
    num_ffn=2,
    tied_embeddings=True,
)
phi2 = Model(
    'phi-2',
    hidden_size=2560,
    intermediate_size=10240,
    num_attention_heads=32,
    num_key_value_heads=32,
    num_hidden_layers=32,
    vocab_size=51200,
    num_ffn=2,
)

tuple(str(cfg) for cfg in (llama7b, llama70b, mistral7b, opt7b, phi2))

('Llama-7b 6.7B',
 'Llama-70b 69.0B',
 'Mistral-7b 7.2B',
 'opt-7b 6.6B',
 'phi-2 2.8B')

In [33]:
def kvcache(model, seq_len, batch_size):
    """Return the number of elements held in the kv cache.

    When running model inference, the intermediate k/v produced for every
    token are stored in the kv cache. The use of GQA (group query attention)
    significantly reduces the cache size, because the cached width scales
    with ``num_key_value_heads`` instead of ``num_attention_heads``.

    The result is an element count, not a byte count.
    """
    head_width = model.hidden_size // model.num_attention_heads
    kv_width = head_width * model.num_key_value_heads
    # one key vector and one value vector per layer for each cached token
    per_token = model.num_hidden_layers * kv_width * 2
    return batch_size * seq_len * per_token

# kv cache element counts for llama7b and mistral7b at seq_len=4096, batch_size=1
kvcache(llama7b, 4096, 1), kvcache(mistral7b, 4096, 1)

(1073741824, 268435456)

In [37]:
def bandwidth(model, seq_len, batch_size, modelbits=16, cachebits=16):
    """Return the bytes occupied by the model weights plus the kv cache.

    When running inference in auto-regressive mode, both the model and the
    kv cache need to be read. ``modelbits`` and ``cachebits`` are the storage
    widths, in bits, of one weight and one cache element respectively; each
    part is truncated to a whole number of bytes before the two are summed.
    """
    weight_bytes = int(model.params() * modelbits / 8)
    cache_bytes = int(kvcache(model, seq_len, batch_size) * cachebits / 8)
    return weight_bytes + cache_bytes

def flops(model, seq_len, batch_size):
    """Return the estimated floating-point operation count for the model and kv cache.

    Computations are needed on both the model and the kv cache. Every element
    read from either is charged 2 flops, as it participates in one FMA; the
    weights are charged once per sequence in the batch. A few other ops, such
    as softmax, are not counted because they aren't significant for large
    models.

    NOTE(review): with GQA a cached k/v element may be shared by several
    query heads, in which case one FMA per cached element would undercount
    the attention work - confirm whether that matters for the intended use.
    """
    weight_flops = model.params() * batch_size * 2
    cache_flops = kvcache(model, seq_len, batch_size) * 2
    return weight_flops + cache_flops

In [117]:
# batch size divided by (bandwidth() bytes / 1935e9) for llama7b at seq_len=2048
# NOTE(review): 1935e9 is presumably the device memory bandwidth in bytes/s, which
# would make the result a bandwidth-bound tokens/s figure for the whole batch - confirm
bs = 16
bs/(bandwidth(llama7b, 2048, bs) / 1935e9)

1009.9109603207517

In [120]:
# batch size divided by (flops() count / 19e12) for llama7b at seq_len=2048
# NOTE(review): 19e12 is presumably the device peak throughput in FLOP/s, which
# would make the result a compute-bound tokens/s figure for the whole batch - confirm
bs = 16
bs/(flops(llama7b, 2048, bs) / 19e12)

1305.8382827701607