# Optimizing Throughput: JIT

We've shown that llama-jax's average throughput is 10 tokens per second (tps) short of Ollama's: 25 vs. 35 tps on a 64GB M1 MacBook Pro. The goal of this series of optimization experiments is to see how closely we can match Ollama's throughput.

## JIT

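JAX's just-in-time (JIT) compiler is a key optimization tool, but where should we apply it? Here we compare throughput (tps) when jit is applied at the very top level (one compiled function per decode step) versus the second level (`transform` and `next_token_id` compiled separately).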

# Setup

In [None]:
from collections.abc import Callable, Iterator
from concurrent.futures import wait,ThreadPoolExecutor
from functools import partial
from queue import SimpleQueue
from time import time_ns as seed

import jax
from jax import Array, numpy as jnp, random
from jax.nn import softmax
import ollama
from pandas import DataFrame
from pydantic import BaseModel
import seaborn as sns
from tqdm.auto import tqdm

import llama_jax as ll
from llama_jax.chat import Message
from llama_jax.checkpoint import ModelConfig
from llama_jax.model import Model
from llama_jax.kvc import KVCache
from llama_jax.tools import default_arg

# Parameters

In [2]:
# Experiment size: how many prompts to benchmark, and how many tokens to generate per prompt.
n_prompts, max_tokens = 20, 50

# Prompts

In [None]:
# Response schema for the prompt-generation request: Ollama is asked to return
# JSON matching this model via `format=PromptPool.model_json_schema()`.
# Deliberately documented with `#` comments rather than a class docstring:
# pydantic copies a model docstring into the JSON schema's "description",
# which would change the schema sent to the model.
class PromptPool(BaseModel):
    # The generated prompt strings.
    prompts: list[str]

# Ask a local Ollama model to invent the benchmark prompts, constraining its
# output to the PromptPool JSON schema.
prompt = ll.tools.prompt(
    f"""
    Generate {n_prompts} LLM prompts. Each prompt should pose an interesting question about math and science in 3 to 6 words.
    Your response must be formatted as JSON.
    """
)

response = ollama.chat(
    model="llama3.2:3b",
    messages=[{"role": "user", "content": prompt}],
    format=PromptPool.model_json_schema(),
)

prompt_pool = PromptPool.model_validate_json(response.message.content)

# The model does not always honor the requested count. Cap the pool so every
# pipeline is benchmarked on at most `n_prompts` prompts.
prompts = prompt_pool.prompts[:n_prompts]

# An empty pool would silently produce empty metrics for every pipeline.
if not prompts:
    raise ValueError("Prompt generation returned no prompts")

prompts

In [4]:
# Prompt used only for the warmup pass that triggers jit compilation; the timed
# runs iterate over `prompts` instead.
warmup_prompt = "Why is the sky blue?"

# Building Blocks

In [5]:
def transform(
    config: ModelConfig,
    state: Model,
    token_ids: Array,
    position_mask: Array,
    *,
    kvc: KVCache | None = None,
) -> Array | tuple[Array, KVCache]:
    """Run the model forward pass, mapping token ids to next-token logits.

    Args:
        config: Model configuration, forwarded to every llama_jax component.
        state: Model parameters (embeddings, layers, rope, head).
        token_ids: 2-D array of token ids.
        position_mask: Forwarded to the attention-mask builder and to the head.
        kvc: Optional key/value cache. When supplied, the updated cache is
            returned alongside the logits. Otherwise a fresh cache is created
            internally and not returned.

    Returns:
        The logits, or ``(logits, kvc)`` when ``kvc`` was supplied.
    """

    # Decide whether to hand the cache back before the default replaces None
    return_cache = kvc is not None

    cache = default_arg(kvc, default_factory=partial(ll.kvc.create, config))

    assert token_ids.ndim == 2

    # Token embeddings and the attention mask shared by every layer
    hidden = ll.embeddings.forward(config, state.embeddings, token_ids)
    attn_mask = ll.attention.attention_mask(config, position_mask)

    # KVCache is an immutable sequence: update a list copy, then rebuild it
    layer_caches = list(cache)
    for idx, layer_params in enumerate(state.layers):
        hidden, layer_caches[idx] = ll.layer.forward(
            config, layer_params, state.rope, attn_mask, hidden, layer_caches[idx]
        )
    cache = KVCache(layer_caches)

    logits = ll.head.forward(config, state.head, hidden, position_mask)

    return (logits, cache) if return_cache else logits


def next_token_id(
    logits: Array,
    *,
    key: Array | None = None,
    temperature: float | None = None,
    top_k: int | None = None,
    top_p: float | None = None,
) -> Array:
    """Select next token id using temperature, top k, and top p sampling.

    Args:
        logits: Logits whose last axis indexes the vocabulary; any leading
            axes are treated as batch dimensions.
        key: PRNG key. Required unless ``temperature`` is 0 (greedy decoding).
        temperature: Logit scaling temperature; defaults to 0.6. A value of 0
            selects the arg-max token.
        top_k: Number of highest-scoring candidates kept; defaults to 50.
        top_p: Nucleus threshold applied to the top-k candidates; defaults to 0.9.

    Returns:
        Token ids with the same leading axes as ``logits`` and a trailing axis
        of size 1.
    """

    # Temperature
    # -----------

    # Defaults
    temperature = default_arg(temperature, 0.6)

    # Greedy decoding: temperature 0 always takes the highest logit
    if temperature == 0:
        return jnp.argmax(logits, axis=-1, keepdims=True)

    # Apply temperature
    logits = logits / temperature

    # Ranking
    # -------

    # Sort logits in descending order, maintaining original indices
    indices = jnp.argsort(logits, axis=-1, descending=True)

    # Top K
    # -----

    # Defaults
    top_k = default_arg(top_k, 50)

    # Keep the top_k best candidates along the vocabulary (last) axis. Ellipsis
    # indexing supports any number of leading batch dimensions.
    indices = indices[..., :top_k]
    logits = jnp.take_along_axis(logits, indices, axis=-1)

    # Top P
    # -----

    # Defaults
    top_p = default_arg(top_p, 0.9)

    # Convert remaining logits to probabilities
    probs = softmax(logits, axis=-1)

    # Count the leading candidates whose cumulative probability stays within top_p
    cumulative_mask = probs.cumsum(axis=-1) <= top_p
    cutoff = jnp.sum(cumulative_mask, axis=-1, keepdims=True)

    # Keep positions <= cutoff, which also keeps the first candidate that pushes
    # the cumulative probability past top_p
    mask = jnp.broadcast_to(jnp.arange(logits.shape[-1]), logits.shape) <= cutoff

    # Exclude candidates beyond the cutoff. They must be set to -inf, not 0:
    # categorical sampling gives a logit of 0 weight exp(0) = 1, so a zeroed
    # candidate would stay selectable. It would even outrank any kept
    # candidate whose scaled logit is negative.
    logits = jnp.where(mask, logits, -jnp.inf)

    # Random Selection
    # ----------------

    assert key is not None, "key is required when temperature != 0"

    # Randomly choose a position among the remaining candidates
    _, subkey = random.split(key)
    selected = random.categorical(subkey, logits, axis=-1)[..., None]

    # Map selected positions back to original vocabulary indices
    return jnp.take_along_axis(indices, selected, axis=-1)

In [6]:
# Load the Llama3.2-3b-Instruct config and tokenizer, and create the model.
# The pipelines below read `config`, `tokenizer` and `model` as module globals.
config = ll.checkpoint.load_config("Llama3.2-3b-Instruct")
tokenizer = ll.checkpoint.load_tokenizer(config)
model = ll.model.create(config)

## Experiment Run

In [7]:
# A pipeline: takes a user prompt and yields decoded tokens as they are generated.
Generator = Callable[[str], Iterator[str]]


def run(pipeline_name: str, generator: Generator):
    """Stream each benchmark prompt through `generator` and record its throughput.

    Returns one dict per prompt with keys "pipeline", "prompt" (the index into
    `prompts`) and "tps" (read from the token view after it has closed).
    """

    def _measure(text: str):
        # Render tokens as they stream; tps is read once the view has exited
        with ll.render.token_view(prompt=text) as view:
            for piece in generator(text):
                view.add_token(piece)
        return view.tps

    return [
        {"pipeline": pipeline_name, "prompt": idx, "tps": _measure(text)}
        for idx, text in enumerate(prompts)
    ]

In [8]:
# Reset metrics: each pipeline's per-prompt results are appended here and
# turned into a DataFrame in the Analysis section.
pipeline_metrics = []

# Pipeline 1: JIT Top Level

Compile each decode step (`transform` followed by `next_token_id`) as one large jit function.

In [13]:
pipeline_name = "pipeline1"


@partial(jax.jit, static_argnames="config")
def predict(config, model, x, position_mask, kvc, key):
    """One decode step (forward pass plus sampling), compiled as a single jit function.

    `config` is a static argument, so it must be hashable. The remaining
    arguments are traced. Returns ``(token_id, kvc)``.
    """
    logits, updated_cache = transform(config, model, x, position_mask, kvc=kvc)
    return next_token_id(logits, key=key), updated_cache


def pipeline(content: str) -> Iterator[str]:
    """Generate `max_tokens` tokens for `content`, yielding each decoded token.

    The first step feeds the whole encoded prompt. Each later step feeds only
    the previously sampled token, together with the updated kv cache.
    """
    root_key = random.key(seed())

    chat_prompt = ll.chat.render_prompt([Message(role="user", content=content)])
    token_ids, position_mask = tokenizer.encode(chat_prompt)

    x = token_ids
    cache = ll.kvc.create(config)
    # One pre-split subkey per generated token; the first split key is unused
    _, *step_keys = random.split(root_key, max_tokens + 1)

    for step_key in step_keys:
        token_id, cache = predict(config, model, x, position_mask, cache, step_key)

        yield tokenizer.decode(token_id)[0]

        # NOTE(review): position_mask is not advanced between steps; confirm
        # llama_jax tracks positions for cached steps elsewhere.
        x = token_id

## Warmup

In [None]:
# Clear JAX's compilation caches so this pipeline starts cold, not helped by
# executables compiled for an earlier pipeline.
jax.clear_caches()

# Drain the generator once so jit compilation happens before the timed run.
# NOTE(review): jit specializes on input shapes. If the benchmark prompts
# encode to lengths other than warmup_prompt's, the first step may still
# recompile during the timed run; confirm against tokenizer.encode.
for token in tqdm(pipeline(warmup_prompt), desc="Warmup", unit_scale=True):
    pass

## Run

In [None]:
# Timed run: benchmark every prompt, then display the per-prompt metrics.
metrics = run(pipeline_name, pipeline)
metrics

In [16]:
# Accumulate this pipeline's results for the Analysis section.
pipeline_metrics += metrics

# Pipeline 2: Transform / Next Token JIT

Apply jit to `transform` and `next_token_id` separately instead.

In [17]:
pipeline_name = "pipeline2"

# Compile the forward pass and the sampler as two separate jit functions.
_transform = jax.jit(transform, static_argnames="config")
# `temperature` feeds a Python-level `== 0` branch and `top_k` is a slice bound
# inside next_token_id. Both must therefore be static for a caller to override
# them under jit. Left at their None defaults, behavior is unchanged.
_next_token_id = jax.jit(next_token_id, static_argnames=("temperature", "top_k"))


def pipeline(content: str) -> Iterator[str]:
    """Generate `max_tokens` tokens for `content`, yielding each decoded token.

    Each step makes two jit calls: `_transform` to get logits and the updated
    kv cache, then `_next_token_id` to sample. The first step feeds the whole
    encoded prompt; later steps feed only the previously sampled token.
    """

    key = random.key(seed())

    prompt = ll.chat.render_prompt([Message(role="user", content=content)])
    token_ids, position_mask = tokenizer.encode(prompt)

    x = token_ids
    kvc = ll.kvc.create(config)
    # One pre-split subkey per generated token; the first split key is unused
    key, *subkeys = random.split(key, max_tokens + 1)

    for i in range(max_tokens):

        logits, kvc = _transform(config, model, x, position_mask, kvc=kvc)
        token_id = _next_token_id(logits, key=subkeys[i])

        yield tokenizer.decode(token_id)[0]

        x = token_id

## Warmup

In [None]:
# Clear JAX's compilation caches so this pipeline starts cold, not helped by
# executables compiled for pipeline1.
jax.clear_caches()

# Drain the generator once so jit compilation happens before the timed run.
# NOTE(review): jit specializes on input shapes. If the benchmark prompts
# encode to lengths other than warmup_prompt's, the first step may still
# recompile during the timed run; confirm against tokenizer.encode.
for token in tqdm(pipeline(warmup_prompt), desc="Warmup", unit_scale=True):
    pass

## Run

In [None]:
# Timed run: benchmark every prompt, then display the per-prompt metrics.
metrics = run(pipeline_name, pipeline)
metrics

In [20]:
# Accumulate this pipeline's results for the Analysis section.
pipeline_metrics += metrics

# Analysis

In [25]:
# One row per (pipeline, prompt) with columns "pipeline", "prompt" and "tps".
data = DataFrame(pipeline_metrics)

In [None]:
# Compare the tps distributions of the pipelines, with bars drawn side by side.
sns.histplot(data, x="tps", hue="pipeline", multiple="dodge")

In [None]:
# pipeline2: summary statistics of the per-prompt tps
data.loc[data["pipeline"] == "pipeline2", "tps"].describe()

In [None]:
# pipeline1: summary statistics of the per-prompt tps.
# Only "pipeline1" and "pipeline2" are benchmarked in this notebook; filtering
# on any other name yields an empty summary.
data[data.pipeline == "pipeline1"].tps.describe()