# Text Generation

Exploring raw text generation using llama-jax with and without KV caching.

# Setup

In [1]:
from sys import stdout
from time import time_ns

import jax
import jax.numpy as jnp
from jax import random

import llama_jax as ll

In [2]:
# Report which accelerator devices JAX can see before any model work starts
available_devices = jax.devices()
print(f"Available devices: {available_devices}")



Metal device set to: Apple M1 Max
Available devices: [METAL(id=0)]


I0000 00:00:1736027980.560877 16189038 service.cc:145] XLA service 0x600000a50000 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1736027980.560886 16189038 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1736027980.562339 16189038 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1736027980.562348 16189038 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.


# Configure

In [3]:
# Parameters
seed = time_ns()  # wall-clock seed: sampled text differs from run to run
max_tokens = 20  # number of tokens produced by each generation loop below

In [4]:
# Load the "Llama3.2-3B" checkpoint: its config, then the parameters and tokenizer for that config
config = ll.checkpoint.load_config("Llama3.2-3B")
params = ll.checkpoint.load_parameters(config)
tokenizer = ll.checkpoint.load_tokenizer(config)

# Build the model from the loaded config and parameters
model = ll.model.create(config, params)

# PRNG key used for token sampling, derived from the seed in the parameters cell
key = random.key(seed)

# Prompt 0

In [5]:
# Raw text prompt; both generation loops below continue this sentence
prompt = "I like traveling by train because"

## Tokenize

In [6]:
# Encode the prompt into token ids. The displayed result is a 2-D int32 array with a single row;
# the leading id (128000) is presumably a beginning-of-text marker added by the tokenizer — TODO confirm.
tokens = tokenizer.encode(prompt)
tokens

Array([[128000,     40,   1093,  21646,    555,   5542,   1606]], dtype=int32)

## Without KV Cache

In [9]:
# Generation without a KV cache: each step feeds the entire sequence so far back through the model.

# Print input tokens so the generated text continues on the same line
stdout.write(prompt)

# Process all tokens on first pass
x = tokens

for _ in range(max_tokens):
    # Transform tokens into logits
    logits = ll.model.forward(config, model, x)

    # Sample next token; the returned key replaces the old one for the next draw
    next_token, key = ll.model.sample_tokens(logits, key)

    # Print next token as soon as it is sampled
    stdout.write(tokenizer.decode(next_token)[0])

    # Append the new token along the last axis so the next pass processes all tokens again
    x = jnp.concat([x, next_token], axis=-1)

# Terminate the output line; binding to `_` keeps the write's return value out of the cell output
_ = stdout.write("\n")

I like traveling by train because it is so relaxing and there is plenty of time to do things like read or just look out the


## With KV Cache

In [10]:
# Generation with a KV cache: after the first pass only the newest token is fed to the model.

# Echo the prompt so the generated text continues on the same line
stdout.write(prompt)

# Start from an empty cache created for this model config
kv_cache = ll.kv_cache.create(config)

# The first pass consumes the whole prompt
x = tokens

for _step in range(max_tokens):
    # Forward pass returns the logits together with the updated cache
    logits, kv_cache = ll.model.forward(config, model, x, kv_cache=kv_cache)

    # Draw the next token; the returned key replaces the old one for the next draw
    next_token, key = ll.model.sample_tokens(logits, key)

    # Decode and emit the token text immediately
    token_text = tokenizer.decode(next_token)[0]
    stdout.write(token_text)

    # Only the freshly generated token is processed on the next pass
    x = next_token

# Terminate the output line; binding to `_` keeps the write's return value out of the cell output
_ = stdout.write("\n")

I like traveling by train because it’s more relaxing than flying. I can read and enjoy the scenery. There is also a lot
