# Text Completions

Exploring text completions using llama-jax.

# Setup

In [1]:
from collections.abc import Sequence
from contextlib import contextmanager
from sys import stdout
from time import time_ns as seed, perf_counter_ns as timer

import jax
from jax import random
from rich.live import Live
from rich.table import Table

import llama_jax as ll



Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB



I0000 00:00:1741022517.078522 11838555 service.cc:145] XLA service 0x121c59c80 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1741022517.078539 11838555 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1741022517.079948 11838555 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1741022517.079961 11838555 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.


In [2]:
@contextmanager
def tps_report(n_tokens: int):
    """Time the enclosed block and print a tokens-per-second summary when it exits.

    Args:
        n_tokens: Number of tokens to attribute to the timed block. The value is
            used as given; this function does not count tokens itself.

    The summary is printed only when the block finishes normally: an exception
    raised inside the block propagates and no report is written.
    """
    ns_per_second = 1_000_000_000
    began = timer()

    yield

    # timer() counts nanoseconds; convert the elapsed span to seconds.
    elapsed_ns = timer() - began
    duration = elapsed_ns / ns_per_second
    tps = n_tokens / duration

    print(f"\nGenerated {n_tokens} tokens in {duration:0.1f} s ({tps:0.1f} tps)")


def render(content: str | Sequence[str], token: int | None = None):
    """Build a single-row rich Table with one equal-width column per text.

    Args:
        content: A single string, or a sequence of strings (one per column).
        token: Zero-based index of the latest token. When given, the first
            column's header reads "Tokens {token+1}"; when None, no title is set.

    Returns:
        The populated Table, ready to be passed to a rich Live display.
    """
    if isinstance(content, str):
        content = [content]

    title = f"Tokens {token+1}" if token is not None else None

    table = Table(
        show_header=True,
        show_edge=False,
        expand=True,
    )

    for i in range(len(content)):
        # Equal integer ratios split the width evenly. This replaces the float
        # 1/n, which does not match rich's integer `ratio` parameter and raised
        # ZeroDivisionError for an empty sequence.
        table.add_column(header=(title if i == 0 else ""), ratio=1)

    table.add_row(*content)

    return table

In [3]:
# Parameters
# Passed to the generator as max_tokens and to tps_report as the token count.
max_tokens = 30

In [4]:
# Configure generator
config = ll.checkpoint.load_config("Llama3.2-3B")
# The PRNG key is seeded from the current time (time_ns), so generated text will
# presumably differ from run to run; use a fixed seed if reproducible output is needed.
generator = ll.text.generator(config, key=random.key(seed()), max_tokens=max_tokens)

# Warmup

In [5]:
prompt = "My name is Julien and I like to"

In [6]:
# First generation of the session; its timing presumably includes one-off start-up
# cost (e.g. JIT compilation — TODO confirm), so it is run once before the timed prompts.
# NOTE: tps_report is told max_tokens, which assumes the generator yields exactly that many tokens.
with tps_report(max_tokens):
    for token in generator(prompt):
        # The token itself is discarded; one dot per token serves as a progress indicator.
        stdout.write(".")

..............................
Generated 30 tokens in 25.1 s (1.2 tps)


# Prompt 0: Trains

In [7]:
prompt = "I like traveling by train because"

In [8]:
# Start the display with the prompt and append each generated token as it arrives.
content = prompt

# tps_report prints its timing summary when the block exits; Live shows the table and is refreshed via update().
with tps_report(max_tokens), Live(render(content)) as live:    
    for i, token in enumerate(generator(prompt)):
        content += token
        # i is the zero-based token index; render shows it as a 1-based count in the header.
        live.update(render(content, token=i))

Output()


Generated 30 tokens in 8.1 s (3.7 tps)


# Prompt 1: Paris

In [9]:
prompt = "Paris is an amazing place to visit,"

In [10]:
# Start the display with the prompt and append each generated token as it arrives.
content = prompt

# tps_report prints its timing summary when the block exits; Live shows the table and is refreshed via update().
with tps_report(max_tokens), Live(render(content)) as live:    
    for i, token in enumerate(generator(prompt)):
        content += token
        # i is the zero-based token index; render shows it as a 1-based count in the header.
        live.update(render(content, token=i))

Output()


Generated 30 tokens in 5.9 s (5.1 tps)


# Prompt 2: Once Upon a Time

In [11]:
prompt = "Once upon a time"

In [12]:
# Start the display with the prompt and append each generated token as it arrives.
content = prompt

# tps_report prints its timing summary when the block exits; Live shows the table and is refreshed via update().
with tps_report(max_tokens), Live(render(content)) as live:    
    for i, token in enumerate(generator(prompt)):
        content += token
        # i is the zero-based token index; render shows it as a 1-based count in the header.
        live.update(render(content, token=i))

Output()


Generated 30 tokens in 8.0 s (3.8 tps)


# Batched

In [13]:
# The same three prompts used one at a time in the sections above, now passed together.
prompts = (
    "I like traveling by train because",
    "Paris is an amazing place to visit,",
    "Once upon a time",
)

In [15]:
# One column of text per prompt, each starting from the prompt itself.
content = prompts

with Live(render(content)) as live:
    # This call passes max_tokens=50 explicitly instead of relying on the value
    # the generator was configured with.
    for i, tokens in enumerate(generator(prompts, max_tokens=50)):
        # Element k of `tokens` is appended to column k. The comprehension uses its
        # own variable names so it no longer shadows the step index `i` used below.
        content = [content[col] + piece for col, piece in enumerate(tokens)]
        live.update(render(content, token=i))

Output()