# Text Generation

Exploring raw text generation with llama-jax and comparing generation throughput with and without KV caching.

# Setup

In [1]:
from sys import stdout
from time import time_ns, perf_counter_ns as timer

import jax
from jax import numpy as jnp
from jax import random

import llama_jax as ll

In [2]:
# Show which accelerators JAX can see (the saved output shows METAL on an Apple M1 Max)
available = jax.devices()
print(f"Available devices: {available}")



Metal device set to: Apple M1 Max
Available devices: [METAL(id=0)]


I0000 00:00:1736123954.557401 17206389 service.cc:145] XLA service 0x6000012e2600 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1736123954.557413 17206389 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1736123954.558573 17206389 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1736123954.558584 17206389 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.


# Configure

In [3]:
# Parameters
# Time-based seed: sampled text differs on every run of the notebook
seed = time_ns()
max_tokens = 20  # tokens generated per experiment (also read by print_tps)

In [4]:
def print_tps(start=None, n_tokens=None):
    """Print how many tokens were generated, how long it took, and the tokens/second rate.

    Args:
        start: Start timestamp in nanoseconds from ``timer()`` (``perf_counter_ns``).
            Defaults to the notebook-global ``start_time`` so existing calls keep working.
        n_tokens: Number of tokens generated. Defaults to the notebook-global
            ``max_tokens``; pass the real count when a run may yield fewer tokens.

    Returns:
        None. The function only prints, so a trailing call does not add an ``Out[]`` value.
    """
    # Fall back to the globals the original implicitly relied on.
    if start is None:
        start = start_time
    if n_tokens is None:
        n_tokens = max_tokens

    # timer() is perf_counter_ns, so convert nanoseconds to seconds.
    duration = (timer() - start) / 1_000_000_000
    tps = n_tokens / duration
    print(f"Generated {n_tokens} tokens in {duration:0.1f} s ({tps:0.1f} tps)")

In [5]:
# Configure model: load the Llama3.2-3B checkpoint's config, parameters and tokenizer
config = ll.checkpoint.load_config("Llama3.2-3B")
params = ll.checkpoint.load_parameters(config)
tokenizer = ll.checkpoint.load_tokenizer(config)
model = ll.model.create(config, params)

# Root PRNG key derived from `seed`; later cells split or advance it for sampling
key = random.key(seed)

# Prompt

In [6]:
# Shared prompt: every generation experiment below continues this text
prompt = "I like traveling by train because"

# Generator

First, let's use the complete text generator pipeline.

In [7]:
# Hand the generator its own subkey and keep `key` for the manual loops below
key, subkey = random.split(key)
generator = ll.text.generator(
    config,
    model=model,
    key=subkey,
    max_tokens=max_tokens,
)

In [8]:
start_time = timer()

# Echo the prompt, then stream each generated piece of text right after it
stdout.write(prompt)
for token in generator(prompt):
    stdout.write(token)

_ = stdout.write("\n\n")

# NOTE(review): the rate assumes exactly max_tokens were yielded; confirm the generator cannot stop early
print_tps()

I like traveling by train because I can take my time and enjoy the view. I also like traveling by train because it’s more

Generated 20 tokens in 13.6 s (1.5 tps)


# Tokenize

In [9]:
# Encode the prompt; the output below is a 2-D array holding a single row of ids.
# NOTE(review): the leading 128000 is presumably Llama 3's begin-of-text token — confirm in the tokenizer
token_ids = tokenizer.encode(prompt)
token_ids

Array([[128000,     40,   1093,  21646,    555,   5542,   1606]], dtype=int32)

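# Without KV Cache

Without a cache, every step runs the model over the entire sequence: the prompt plus every token generated so far.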

In [10]:
start_time = timer()

# Echo the prompt; sampled tokens are streamed right after it
stdout.write(prompt)

# The first pass sees only the prompt; each later pass sees the prompt plus all generated tokens
x = token_ids

for _ in range(max_tokens):
    # Run the model over the whole sequence, then sample (and advance `key`)
    logits = ll.model.forward(config, model, x)
    next_token_id, key = ll.model.next_token(logits, key)

    # Decode the single sampled id and stream it
    piece = tokenizer.decode(next_token_id)[0]
    stdout.write(piece)

    # Append the new id along the last axis; the next forward recomputes over everything.
    # NOTE(review): if forward is jit-compiled, each new length may recompile inside this timing — confirm
    x = jnp.concat([x, next_token_id], axis=-1)

_ = stdout.write("\n\n")

print_tps()

I like traveling by train because I have the chance to sit back and relax and read a book. The train is a good place

Generated 20 tokens in 14.0 s (1.4 tps)


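# With KV Cache

With a cache, the first pass processes the whole prompt. Each later pass feeds only the newly sampled token, together with the cache returned by the previous pass.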

In [11]:
start_time = timer()

# Echo the prompt; sampled tokens are streamed right after it
stdout.write(prompt)

# Start this run from a freshly created cache
kv_cache = ll.kv_cache.create(config)

# Only the first pass feeds the full prompt
x = token_ids

for _ in range(max_tokens):
    # forward returns logits plus the cache to feed into the next step.
    # NOTE(review): presumably the returned cache now includes entries for `x` — confirm in ll.model.forward
    logits, kv_cache = ll.model.forward(config, model, x, kv_cache=kv_cache)
    next_token_id, key = ll.model.next_token(logits, key)

    # Decode the single sampled id and stream it
    piece = tokenizer.decode(next_token_id)[0]
    stdout.write(piece)

    # Later passes feed just the newest token; earlier context comes from kv_cache
    x = next_token_id

_ = stdout.write("\n\n")

print_tps()

I like traveling by train because it gives me the opportunity to see the country and learn about the culture. I also like the fact

Generated 20 tokens in 2.1 s (9.7 tps)
