# Text Completions

Exploring text completions using llama-jax.

# Setup

In [1]:
from contextlib import contextmanager
from sys import stdout
from time import time_ns, perf_counter_ns as timer

import jax
from jax import numpy as jnp
from jax import random

import llama_jax as ll

In [2]:
print(f"Available devices: {jax.devices()}")

Metal device set to: Apple M1 Max
Available devices: [METAL(id=0)]

In [3]:
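@contextmanager
def tps_report(n_tokens: int):
    """Time the enclosed block and report throughput in tokens per second.

    n_tokens is the expected token count; it is not measured.
    """
    start_time = timer()

    yield

    # Separate the streamed completion from the report
    stdout.write("\n\n")

    # perf_counter_ns returns nanoseconds; convert to seconds
    duration = (timer() - start_time) / 1e9
    tps = n_tokens / duration

    print(f"Generated {n_tokens} tokens in {duration:0.1f} s ({tps:0.1f} tps)")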
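`tps_report` divides the `n_tokens` it is given by the elapsed time, so it reports `max_tokens` even if a generator were to stop early (for example at an end-of-sequence token). Whether llama-jax ever stops early is not shown here. The following is a minimal sketch of a counting variant; `TokenCounter` and `tps_report_counted` are hypothetical names, not part of llama-jax, and the token list is a stand-in for `generator(prompt)`.

```python
from contextlib import contextmanager
from time import perf_counter_ns
from typing import Iterable, Iterator


class TokenCounter:
    """Hypothetical helper: wraps a token iterable and counts what it yields."""

    def __init__(self, tokens: Iterable[str]):
        self._tokens = iter(tokens)
        self.count = 0

    def __iter__(self) -> Iterator[str]:
        for token in self._tokens:
            self.count += 1
            yield token


@contextmanager
def tps_report_counted(tokens: Iterable[str]):
    """Hypothetical variant of tps_report that measures the token count."""
    counter = TokenCounter(tokens)
    start = perf_counter_ns()
    try:
        yield counter
    finally:
        # Report whatever was consumed, even if the loop raised or broke early
        duration = (perf_counter_ns() - start) / 1e9
        tps = counter.count / duration if duration > 0 else float("inf")
        print(f"\n\nGenerated {counter.count} tokens in {duration:0.1f} s ({tps:0.1f} tps)")


# Usage with a stand-in token stream (in the notebook: generator(prompt))
with tps_report_counted([" it", "'s", " very", " relaxing", "."]) as tokens:
    for token in tokens:
        print(token, end="")
```

The `try`/`finally` makes the report appear even if generation is interrupted, which the original context manager does not do.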

# Configure

In [4]:
# Parameters
max_tokens = 30
seed = time_ns()

In [5]:
# Configure generator
key = random.key(seed)
config = ll.checkpoint.load_config("Llama3.2-3B")
generator, key = ll.text.generator(config, key, max_tokens=max_tokens)
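`ll.text.generator` takes a PRNG key and returns one alongside the generator, which matches JAX's explicit PRNG convention: a key is split, one half is consumed, and the other is handed back for later use. The sketch below illustrates that convention with an autoregressive sampling loop, assuming a toy bigram logits table in place of Llama; `make_toy_generator`, `VOCAB`, and `LOGITS` are hypothetical and do not describe llama-jax's internals. Note also that `seed = time_ns()` makes each run of this notebook draw different keys, so completions are not reproducible across runs.

```python
from jax import random

# Toy vocabulary and bigram "model" (hypothetical stand-ins for Llama)
VOCAB = ["<s>", "I", "like", "trains", "because", "."]
LOGITS = random.normal(random.key(0), (len(VOCAB), len(VOCAB)))


def make_toy_generator(key, max_tokens):
    """Return a token generator plus an unused key for the caller."""
    # Split once: one half is captured by the generator, the other returned
    key, subkey = random.split(key)

    def generate(prompt):
        # Map prompt words to ids; unknown words fall back to "<s>"
        ids = [VOCAB.index(w) if w in VOCAB else 0 for w in prompt.split()]
        token_id = ids[-1] if ids else 0
        step_key = subkey
        for _ in range(max_tokens):
            # Fresh subkey per step so each sample uses independent randomness
            step_key, sample_key = random.split(step_key)
            token_id = int(random.categorical(sample_key, LOGITS[token_id]))
            yield " " + VOCAB[token_id]

    return generate, key


generate, key = make_toy_generator(random.key(1234), max_tokens=5)
first = "".join(generate("I like"))
second = "".join(generate("I like"))
print(first)
```

Because this toy generator captures a single subkey, calling it twice with the same prompt produces the same continuation; a design that advances its key between calls would not.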

# Warmup

In [6]:
prompt = "My name is Julien and I like to"

In [7]:
with tps_report(max_tokens):
    # Echo the prompt
    stdout.write(prompt)

    # Stream generated tokens as they arrive
    for token in generator(prompt):
        stdout.write(token)
        stdout.flush()

My name is Julien and I like to play games.
I am a passionate gamer who loves to play games, and I want to share my passion with the world. I am a member of

Generated 30 tokens in 21.9 s (1.4 tps)
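The warmup run reached only 1.4 tps, while the later 30-token prompts ran at roughly 10 to 27 tps. One likely contributor, assuming llama-jax wraps its model functions in `jax.jit` (this notebook does not show that), is that JAX traces and compiles a jitted function on its first call and reuses the cached executable afterwards. The sketch below shows that effect on a hypothetical `step` function standing in for a decode step; exact timings depend on the device.

```python
from time import perf_counter

import jax
from jax import numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in for one decode step: a few matrix multiplies
    for _ in range(4):
        x = jnp.tanh(x @ x.T)
    return x


def timed(fn, x):
    start = perf_counter()
    out = fn(x).block_until_ready()  # wait for asynchronous dispatch to finish
    return out, perf_counter() - start


x = jnp.ones((256, 256)) / 256
first, t_first = timed(step, x)    # includes tracing and XLA compilation
second, t_second = timed(step, x)  # reuses the cached executable
print(f"first call: {t_first * 1e3:.1f} ms, second call: {t_second * 1e3:.1f} ms")
```

`block_until_ready()` matters here: JAX dispatches work asynchronously, so timing without it can measure only the dispatch rather than the computation.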


# Prompt 0: Trains

In [8]:
prompt = "I like traveling by train because"

In [9]:
with tps_report(max_tokens):
    # Echo the prompt
    stdout.write(prompt)

    # Stream generated tokens as they arrive
    for token in generator(prompt):
        stdout.write(token)
        stdout.flush()

I like traveling by train because it's very relaxing. I enjoy watching the scenery go by and reading a book. It's very different from driving a car and I like that.


Generated 30 tokens in 3.0 s (10.1 tps)


# Prompt 1: Paris

In [10]:
prompt = "Paris is an amazing place to visit,"

In [11]:
with tps_report(max_tokens):
    # Echo the prompt
    stdout.write(prompt)

    # Stream generated tokens as they arrive
    for token in generator(prompt):
        stdout.write(token)
        stdout.flush()

Paris is an amazing place to visit, but the food can be expensive. However, there are plenty of affordable restaurants in Paris that offer delicious and authentic cuisine. Here are some of the best

Generated 30 tokens in 1.1 s (26.8 tps)


# Prompt 2: Once Upon a Time

In [12]:
prompt = "Once upon a time"

In [13]:
with tps_report(max_tokens):
    # Echo the prompt
    stdout.write(prompt)

    # Stream generated tokens as they arrive
    for token in generator(prompt):
        stdout.write(token)
        stdout.flush()

Once upon a time, the only way to get your hands on a good book was to go to the library. In fact, I can remember as a child going to

Generated 30 tokens in 2.8 s (10.6 tps)


# Prompt 3: Once Upon a Time Extended

In [14]:
prompt = "Once upon a time"
n_tokens = 120

In [15]:
with tps_report(n_tokens):
    # Echo the prompt
    stdout.write(prompt)

    # Stream generated tokens, overriding the configured max_tokens
    for token in generator(prompt, max_tokens=n_tokens):
        stdout.write(token)
        stdout.flush()

Once upon a time, the only way to get your hands on a good book was to go to the library. In fact, I can remember as a child going to the library every week with my mother. We would walk around the stacks, looking for the books we wanted to borrow, and then we would go to the checkout desk and wait in line to get our books.
As I got older, I started to get my books from the library less and less. I would borrow them from friends, or I would buy them at the bookstore. But eventually, I stopped going to the library altogether. I had all

Generated 120 tokens in 55.1 s (2.2 tps)
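The extended run managed only 2.2 tps, far below the 30-token runs on the same prompt. One possible explanation, not confirmed by this notebook: if llama-jax sizes its buffers from `max_tokens`, then `max_tokens=120` produces new array shapes, and `jax.jit` retraces and recompiles for every new input shape (or new static-argument value), so the timing would again include compilation as in the warmup. The sketch below counts traces of a hypothetical `decode_buffer` function to show when that happens.

```python
import jax
from jax import numpy as jnp

trace_count = 0


@jax.jit
def decode_buffer(buf):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX is tracing
    return jnp.cumsum(buf)


decode_buffer(jnp.zeros(30))   # traces and compiles for shape (30,)
decode_buffer(jnp.ones(30))    # same shape and dtype: cache hit, no retrace
decode_buffer(jnp.zeros(120))  # new shape (120,): traces and compiles again
print(trace_count)  # 2
```

If this is the cause, a second 120-token run should be much faster than the first one.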
