# Text Completions

Exploring text completions with llama-jax: single-prompt streaming (after a warmup run) and batched generation over several prompts.

# Setup

In [1]:
from collections.abc import Sequence
from contextlib import contextmanager
from sys import stdout
from time import time_ns, perf_counter_ns as timer

import jax
from jax import numpy as jnp
from jax import random
from rich.live import Live
from rich.table import Table, Column

import llama_jax as ll

In [2]:
# List the devices JAX can see in this kernel.
print("Available devices:", jax.devices())



Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

Available devices: [METAL(id=0)]


I0000 00:00:1740513059.967111 4297108 service.cc:145] XLA service 0x126546370 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1740513059.967118 4297108 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1740513059.968517 4297108 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1740513059.968532 4297108 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.


In [3]:
@contextmanager
def tps_report(n_tokens: int):
    """Time the enclosed block and report generation throughput.

    On normal exit from the ``with`` body, writes two newlines to stdout (ending
    any streamed output), then prints the elapsed wall-clock time and the rate
    ``n_tokens / elapsed``. If the body raises, nothing is written or printed.

    Args:
        n_tokens: Token count used for the throughput figure. Callers in this
            notebook pass the requested maximum, so if generation stops early
            the reported rate overstates the real throughput.
    """
    began_ns = timer()
    yield
    stdout.write("\n\n")
    elapsed_s = (timer() - began_ns) / 1_000_000_000
    print(f"Generated {n_tokens} tokens in {elapsed_s:0.1f} s ({n_tokens / elapsed_s:0.1f} tps)")


def render(content: str | Sequence[str]):
    """Lay out one or more text panes side by side in a rich ``Table``.

    Args:
        content: A single string, or a sequence of strings. Each string becomes
            one equal-width column of a single, header-less row.

    Returns:
        A ``rich.table.Table``, used here as the renderable passed to
        ``Live(...)`` and ``Live.update``.
    """
    panes = [content] if isinstance(content, str) else list(content)

    table = Table(show_header=False, expand=True)

    # Equal integer ratios give every pane the same share of the width.
    # rich's ``ratio`` is an int, and ``add_column``'s first argument is the
    # column *header*, so no ``Column`` instance is passed here.
    for _ in panes:
        table.add_column(ratio=1)

    table.add_row(*panes)

    return table

# Configure

In [4]:
# Parameters
max_tokens = 30
seed = time_ns()

In [5]:
# Configure generator
key = random.key(seed)
config = ll.checkpoint.load_config("Llama3.2-3B")
key, subkey = random.split(key)
generator = ll.text.generator(config, key=subkey, max_tokens=max_tokens)

# Warmup

In [6]:
# Prompt for the warmup run in the next cell.
prompt = "My name is Julien and I like to"

In [7]:
# NOTE(review): the reported rate assumes all `max_tokens` tokens were generated.
with tps_report(max_tokens):
    # Print input tokens (echo the prompt so the completion reads as its continuation)
    stdout.write(prompt)
    
    # Generate tokens
    for token in generator(prompt):
        # Print next token as soon as the generator yields it
        stdout.write(token)

My name is Julien and I like to play with light. I am a photographer based in Paris, France. I work as a professional photographer and I also do some commercial work for different brands

Generated 30 tokens in 26.6 s (1.1 tps)


# Prompt 0: Trains

In [8]:
# The displayed text starts out as the bare prompt.
prompt = content = "I like traveling by train because"

In [None]:
# Restart from the bare prompt so that re-running this cell shows a fresh
# completion instead of appending to the previous one.
content = prompt

with tps_report(max_tokens):
    with Live(render(content)) as live:
        # Append each streamed token and redraw the pane in place.
        for token in generator(prompt):
            content += token
            live.update(render(content))

# Prompt 1: Paris

In [10]:
# The displayed text starts out as the bare prompt.
prompt = content = "Paris is an amazing place to visit,"

In [None]:
# Restart from the bare prompt so that re-running this cell shows a fresh
# completion instead of appending to the previous one.
content = prompt

with tps_report(max_tokens):
    with Live(render(content)) as live:
        # Append each streamed token and redraw the pane in place.
        for token in generator(prompt):
            content += token
            live.update(render(content))

# Prompt 2: Once Upon a Time

In [12]:
# The displayed text starts out as the bare prompt.
prompt = content = "Once upon a time"

In [None]:
# Restart from the bare prompt so that re-running this cell shows a fresh
# completion instead of appending to the previous one.
content = prompt

with tps_report(max_tokens):
    with Live(render(content)) as live:
        # Append each streamed token and redraw the pane in place.
        for token in generator(prompt):
            content += token
            live.update(render(content))

# Prompt 3: Once Upon a Time Extended

In [14]:
# Same opening as Prompt 2, but with a larger token budget.
prompt = content = "Once upon a time"
n_tokens = 120

In [None]:
# Restart from the bare prompt so that re-running this cell shows a fresh
# completion instead of appending to the previous one.
content = prompt

with tps_report(n_tokens):
    with Live(render(content)) as live:
        # Override the generator's default budget for this longer run, then
        # append each streamed token and redraw the pane in place.
        for token in generator(prompt, max_tokens=n_tokens):
            content += token
            live.update(render(content))

# Batched

In [14]:
# The three single-prompt examples above, generated together as one batch.
prompts = (
    "I like traveling by train because",
    "Paris is an amazing place to visit,",
    "Once upon a time",
)

In [16]:
# Batched generation at temperature 0.6.
n_tokens, temperature = 120, 0.6
content = prompts  # one pane per prompt; each pane grows as tokens arrive

with Live(render(content)) as live:
    for tokens in generator(prompts, max_tokens=n_tokens, temperature=temperature):
        # `tokens` is indexed in step with `content` (presumably one new token per prompt).
        content = [content[idx] + piece for idx, piece in enumerate(tokens)]
        live.update(render(content))

Output()

In [17]:
# Batched generation at temperature 1.2.
n_tokens, temperature = 100, 1.2
content = prompts  # one pane per prompt; each pane grows as tokens arrive

with Live(render(content)) as live:
    for tokens in generator(prompts, max_tokens=n_tokens, temperature=temperature):
        # `tokens` is indexed in step with `content` (presumably one new token per prompt).
        content = [content[idx] + piece for idx, piece in enumerate(tokens)]
        live.update(render(content))

Output()