# Text Generation

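Explores raw text generation with llama-jax. It compares a loop that re-processes the full sequence on every step against one that reuses a KV cache and feeds only the newest token.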

# Setup

In [1]:
from time import time_ns as seed

from jax import numpy as jnp
from jax import random

import llama_jax as ll



Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB



I0000 00:00:1741022152.168361 11834044 service.cc:145] XLA service 0x11baf3f00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1741022152.168371 11834044 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1741022152.169794 11834044 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1741022152.169807 11834044 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.


# Configure

In [2]:
# Parameters: how many new tokens each generation run below samples
max_tokens: int = 20

In [3]:
# PRNG root key. `seed` is `time.time_ns`, so every run of this cell
# starts from a different key and samples different continuations.
key = random.key(seed())

# Llama3.2-3B checkpoint: config first; tokenizer and model are built from it
config = ll.checkpoint.load_config("Llama3.2-3B")
tokenizer = ll.checkpoint.load_tokenizer(config)
model = ll.model.create(config)

# Prompt

In [4]:
prompt: str = "I like traveling by train because"  # seed text the model continues

In [5]:
# Tokenize the prompt into ids plus the position mask that ll.model.forward also takes.
token_ids, position_mask = tokenizer.encode(prompt)
# Display the ids: a batch holding one sequence (leading 128000 is presumably BOS — confirm)
token_ids

Array([[128000,     40,   1093,  21646,    555,   5542,   1606]], dtype=int32)

# Without KV Cache

Every step runs the model over the prompt plus all tokens generated so far.

In [6]:
with ll.render.token_view(config, prompt=prompt) as tv:

    # One subkey per generated token; the first split output replaces `key`
    # so later cells draw fresh randomness.
    key, *token_keys = random.split(key, max_tokens + 1)

    # No cache: the sequence starts as the encoded prompt and grows by one
    # token per step, and the whole thing is re-processed every step.
    x = token_ids

    for token_key in token_keys:
        # Full-sequence forward pass -> logits.
        # NOTE(review): position_mask keeps the prompt's length while x grows;
        # confirm ll.model.forward tolerates a mask shorter than x.
        logits = ll.model.forward(config, model, x, position_mask)

        # Sample with this step's subkey and stream the decoded text to the view
        next_token_id = ll.model.next_token(logits, key=token_key)
        tv.add_token(tokenizer.decode(next_token_id)[0])

        # Append the sampled token so the next pass sees the full sequence
        x = jnp.concatenate([x, next_token_id], axis=-1)

Output()

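# With KV Cache

The first step processes the whole prompt. Each later step feeds only the newly sampled token, together with the cache returned by the previous call.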

In [7]:
with ll.render.token_view(config, prompt=prompt) as tv:

    # One subkey per generated token; the first split output replaces `key`
    key, *token_keys = random.split(key, max_tokens + 1)

    # Fresh cache, threaded through (and replaced by) every forward call below
    kvc = ll.kvc.create(config)

    # The first pass consumes the whole prompt
    x = token_ids

    for token_key in token_keys:
        # Forward pass returns logits plus the updated cache.
        # NOTE(review): position_mask still describes only the prompt; confirm
        # forward derives positions for later single-token steps from kvc.
        logits, kvc = ll.model.forward(config, model, x, position_mask, kvc=kvc)

        # Sample with this step's subkey and stream the decoded text to the view
        next_token_id = ll.model.next_token(logits, key=token_key)
        tv.add_token(tokenizer.decode(next_token_id)[0])

        # Feed only the sampled token next; earlier context is expected to
        # come from kvc.
        x = next_token_id

Output()