# Text Generation

This notebook explores raw text generation with llama-jax and compares generation throughput (tokens per second) with and without KV caching.

# Setup

In [11]:
from contextlib import contextmanager
from sys import stdout
from time import time_ns, perf_counter_ns as timer

import jax
from jax import numpy as jnp
from jax import random

import llama_jax as ll

In [12]:
# Show which devices JAX reports in this environment
print(f"Available devices: {jax.devices()}")

Available devices: [METAL(id=0)]


In [13]:
@contextmanager
def tps_report(n_tokens: int):
    """Time the enclosed block and print a tokens-per-second summary.

    Args:
        n_tokens: Number of tokens the enclosed block generates. It is only
            used to compute and label the reported rate.

    Yields:
        None. The summary is printed once the block exits normally; if the
        block raises, the exception propagates and no summary is printed.
    """
    start_time = timer()

    yield

    # Stop the clock first so the console writes below are not counted as generation time
    elapsed_ns = timer() - start_time

    # Separate the report from whatever the block streamed to stdout
    stdout.write("\n\n")

    duration = elapsed_ns / 1_000_000_000  # nanoseconds -> seconds

    # Guard against a timer that did not advance between the two readings
    tps = n_tokens / duration if duration > 0 else float("inf")

    print(f"Generated {n_tokens} tokens in {duration:0.1f} s ({tps:0.1f} tps)")

# Configure

In [14]:
# Parameters
seed = time_ns()   # wall-clock based, so the sampling seed differs on every run
max_tokens = 20    # number of tokens sampled by each generation loop

In [15]:
# Configure model: checkpoint config first, then everything derived from it
config = ll.checkpoint.load_config("Llama3.2-3B")
params = ll.checkpoint.load_parameters(config)
tokenizer = ll.checkpoint.load_tokenizer(config)
model = ll.model.create(config, params)

# PRNG key that the generation loops split for sampling
key = random.key(seed)

# Prompt

In [16]:
# Text to continue: it is tokenized and also echoed to stdout ahead of the generated tokens
prompt = "I like traveling by train because"

# Without KV Cache

In [17]:
# Encode the prompt into token ids plus a position mask
token_ids, position_mask = tokenizer.encode(prompt)
# Trailing expression: display both values as the cell output
token_ids, position_mask

(Array([[128000,     40,   1093,  21646,    555,   5542,   1606]], dtype=int32),
 Array([[1, 1, 1, 1, 1, 1, 1]], dtype=int32, weak_type=True))

In [18]:
with tps_report(max_tokens):

    # Print input tokens
    stdout.write(prompt)

    # Process all tokens on first pass
    x = token_ids

    # Advance a cell-local mask instead of rebinding `position_mask`, so the
    # encoded prompt state stays intact and this cell can be re-run on its own
    step_mask = position_mask

    for _ in range(max_tokens):
        # Transform tokens into logits
        logits = ll.model.forward(config, model, x, step_mask)

        # Sample next token (split so every step draws from a fresh subkey)
        key, subkey = random.split(key)
        next_token_id = ll.model.next_token(logits, key=subkey)

        # Print next token
        stdout.write(tokenizer.decode(next_token_id)[0])

        # No cache: the whole sequence, including the new token, is fed on the next pass
        x = jnp.concat([x, next_token_id], axis=-1)
        step_mask = ll.model.increment_position_mask(step_mask)

I like traveling by train because it’s a great way to see the country. I’ve been to the States, the UK and

Generated 20 tokens in 1.4 s (14.5 tps)


# With KV Cache

In [19]:
# Encode the prompt again so this section starts from its own fresh token ids and mask
token_ids, position_mask = tokenizer.encode(prompt)
# Trailing expression: display both values as the cell output
token_ids, position_mask

(Array([[128000,     40,   1093,  21646,    555,   5542,   1606]], dtype=int32),
 Array([[1, 1, 1, 1, 1, 1, 1]], dtype=int32, weak_type=True))

In [20]:
with tps_report(max_tokens):

    # Print input tokens
    stdout.write(prompt)

    # Initialize cache
    kv_cache = ll.kv_cache.create(config)

    # Process all tokens on first pass
    x = token_ids

    # Advance a cell-local mask instead of rebinding `position_mask`, so the
    # encoded prompt state stays intact and this cell can be re-run on its own
    step_mask = position_mask

    for _ in range(max_tokens):
        # Transform tokens into logits; the updated cache is carried into the next step
        logits, kv_cache = ll.model.forward(config, model, x, step_mask, kv_cache=kv_cache)

        # Sample next token (split so every step draws from a fresh subkey)
        key, subkey = random.split(key)
        next_token_id = ll.model.next_token(logits, key=subkey)

        # Print next token
        stdout.write(tokenizer.decode(next_token_id)[0])

        # With the cache, only the newly generated token is fed on the next pass
        x = next_token_id
        step_mask = ll.model.increment_position_mask(step_mask)

I like traveling by train because it gives me time to think. I don’t like traveling by bus because the driver is always talking

Generated 20 tokens in 0.9 s (23.4 tps)
