# Text Completions

This notebook explores text completions with llama-jax and reports generation throughput in tokens per second (tps) for several prompts.

# Setup

In [1]:
from contextlib import contextmanager
from sys import stdout
from time import time_ns, perf_counter_ns as timer

import jax
from jax import numpy as jnp
from jax import random

import llama_jax as ll

In [2]:
print(f"Available devices: {jax.devices()}")



Metal device set to: Apple M1 Max
Available devices: [METAL(id=0)]


I0000 00:00:1736128076.750670 17263942 service.cc:145] XLA service 0x600000b5db00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1736128076.750687 17263942 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1736128076.752452 17263942 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1736128076.752465 17263942 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.


In [3]:
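@contextmanager
def tps_report(n_tokens: int):
    """Time the enclosed block and print throughput, assuming it generates n_tokens tokens."""
    start_time = timer()

    yield

    stdout.write("\n\n")

    # perf_counter_ns returns nanoseconds; convert to seconds
    duration = (timer() - start_time) / 1_000_000_000
    tps = n_tokens / duration

    print(f"Generated {n_tokens} tokens in {duration:0.1f} s ({tps:0.1f} tps)")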
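
`tps_report` divides the requested token count by the elapsed time, so its figure is only accurate when the generator yields exactly that many tokens. A minimal sketch of an alternative, assuming the generator might stop early: wrap the token stream and count what is actually yielded. The name `timed_tokens` is hypothetical and not part of llama-jax.

```python
from time import perf_counter_ns


def timed_tokens(tokens):
    """Yield tokens unchanged, then print throughput for the tokens actually seen."""
    start = perf_counter_ns()
    count = 0

    for token in tokens:
        count += 1
        yield token

    # perf_counter_ns returns nanoseconds; convert to seconds
    duration = (perf_counter_ns() - start) / 1_000_000_000
    tps = count / duration if duration > 0 else float("inf")

    print(f"\n\nGenerated {count} tokens in {duration:0.1f} s ({tps:0.1f} tps)")
```

Any iterable of tokens can be wrapped, for example `for token in timed_tokens(generator(prompt)): stdout.write(token)`.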

# Configure

In [4]:
# Parameters
max_tokens = 30    # default number of tokens to generate per prompt
seed = time_ns()   # wall-clock seed: changes on every run

In [5]:
# Configure generator
key = random.key(seed)
config = ll.checkpoint.load_config("Llama3.2-3B")
generator, key = ll.text.generator(config, key, max_tokens=max_tokens)
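
The generator call above takes a key and hands back a new one, which matches JAX's usual idiom of passing random keys explicitly. The sketch below illustrates that idiom with a toy sampler; `sample_token` and the 4-entry `logits` are hypothetical and say nothing about how llama-jax samples internally.

```python
from jax import numpy as jnp
from jax import random


def sample_token(key, logits):
    """Split the key, sample one token id, and return the id plus the fresh key."""
    key, subkey = random.split(key)
    token_id = random.categorical(subkey, logits)
    return int(token_id), key


logits = jnp.array([2.0, 0.5, 0.1, -1.0])  # toy scores over a 4-token vocabulary
key = random.key(0)

# Each call consumes a subkey and returns a key that has not been used yet
first, key = sample_token(key, logits)
second, key = sample_token(key, logits)
print(first, second)
```

Because the key is threaded through every call, the same seed reproduces the same sequence of samples.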

# Warmup

In [6]:
prompt = "My name is Julien and I like to"

In [7]:
with tps_report(max_tokens):
    # Print input tokens
    stdout.write(prompt)
    
    # Generate tokens
    for token in generator(prompt):
        # Print next token
        stdout.write(token)

My name is Julien and I like to ride my bike.
I'm a french guy living in London and I'm an engineer by trade. I like to travel and discover new places, especially

Generated 30 tokens in 20.4 s (1.5 tps)
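
The warmup run reports 1.5 tps, well below the 6.3 to 10.1 tps of the 30-token prompts that follow. A likely cause, though the notebook does not confirm it, is that the first call pays for JAX tracing and XLA compilation. The sketch below shows the general effect with a toy jitted function; `step` and `time_call` are hypothetical stand-ins, not llama-jax code.

```python
from time import perf_counter_ns

import jax
from jax import numpy as jnp


@jax.jit
def step(x):
    # Toy stand-in for one model step
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))


def time_call():
    start = perf_counter_ns()
    # block_until_ready waits for the asynchronous computation to finish
    result = step(x).block_until_ready()
    return result, (perf_counter_ns() - start) / 1_000_000_000


first_result, first_s = time_call()    # includes tracing and compilation
second_result, second_s = time_call()  # reuses the compiled executable
print(f"first call: {first_s:.4f} s, second call: {second_s:.4f} s")
```

Running one throwaway prompt before measuring, as this section does, keeps that one-time cost out of the later figures.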


# Prompt 0: Trains

In [8]:
prompt = "I like traveling by train because"

In [9]:
with tps_report(max_tokens):
    # Print input tokens
    stdout.write(prompt)
    
    # Generate tokens
    for token in generator(prompt):
        # Print next token
        stdout.write(token)

I like traveling by train because you can see a lot of things. I like it because it’s relaxing, and you can enjoy the scenery.
What is your favorite thing about traveling

Generated 30 tokens in 4.7 s (6.4 tps)


# Prompt 1: Paris

In [10]:
prompt = "Paris is an amazing place to visit,"

In [11]:
with tps_report(max_tokens):
    # Print input tokens
    stdout.write(prompt)
    
    # Generate tokens
    for token in generator(prompt):
        # Print next token
        stdout.write(token)

Paris is an amazing place to visit, there is so much to do and see. I think that I would have a hard time picking my favorite place in the city, so I will give

Generated 30 tokens in 3.0 s (10.1 tps)


# Prompt 2: Once Upon a Time

In [12]:
prompt = "Once upon a time"

In [13]:
with tps_report(max_tokens):
    # Print input tokens
    stdout.write(prompt)
    
    # Generate tokens
    for token in generator(prompt):
        # Print next token
        stdout.write(token)

Once upon a time there was a beautiful princess. She was the most beautiful princess in the world, and she lived in a beautiful palace with her mother. The princess had

Generated 30 tokens in 4.8 s (6.3 tps)


# Prompt 3: Once Upon a Time Extended

In [14]:
prompt = "Once upon a time"
n_tokens = 120

In [15]:
with tps_report(n_tokens):
    # Print input tokens
    stdout.write(prompt)
    
    # Generate tokens
    for token in generator(prompt, max_tokens=n_tokens):
        # Print next token
        stdout.write(token)

Once upon a time there was a beautiful princess. She was the most beautiful princess in the world, and she lived in a beautiful palace with her mother. The princess had two very good friends, a lady named Lady Jane and a lady named Lady Jane.
One day, the princess was sitting in her room when she heard a knock on the door. She opened the door and there stood a beautiful lady, dressed in a beautiful dress. The lady was Lady Jane.
“Hello, princess,” said Lady Jane.
“Hello, Lady Jane,” said the princess.
“May I come in?” asked Lady Jane.
“Of

Generated 120 tokens in 62.6 s (1.9 tps)
