# Chat Completions

Exploring chat completions with llama-jax.

# Setup

In [1]:
from contextlib import contextmanager
from sys import stdout
from time import time_ns, perf_counter_ns as timer

import jax
from jax import numpy as jnp
from jax import random

import llama_jax as ll

In [2]:
# Show which accelerator backends JAX has picked up for this kernel.
devices = jax.devices()
print(f"Available devices: {devices}")



Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

Available devices: [METAL(id=0)]


I0000 00:00:1740264879.152067 75826958 service.cc:145] XLA service 0x11b85cbd0 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1740264879.152079 75826958 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1740264879.153610 75826958 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1740264879.153621 75826958 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.


In [3]:
@contextmanager
def tps_report(n_tokens: int):
    """Time the enclosed block and print a tokens-per-second summary.

    Args:
        n_tokens: Number of tokens attributed to the enclosed block. The
            caller supplies the count; it is not measured here.

    Yields:
        None. Control returns to the ``with`` body while the timer runs.

    The summary is printed only when the body finishes normally; if the body
    raises, the exception propagates and nothing is reported.
    """
    start_time = timer()

    yield

    # Separate the report from whatever the body wrote to stdout.
    stdout.write("\n\n")

    # timer() counts nanoseconds; convert the elapsed time to seconds.
    duration = (timer() - start_time) / 1_000_000_000

    # A zero reading (coarse clock, empty body) would otherwise raise
    # ZeroDivisionError and lose the report entirely.
    if duration > 0:
        tps = n_tokens / duration
    else:
        tps = float("inf") if n_tokens else 0.0

    print(f"Generated {n_tokens} tokens in {duration:0.1f} s ({tps:0.1f} tps)")

In [4]:
# TODO: unfinished sketch of a multi-turn chat helper; not implemented yet.
# def chat(history: Sequence[Message], content: str) -> Sequence[Message]:
#     messages = (*history, Message(role="user", content=content))
    

# Configure

In [5]:
# Parameters
# Seed for the JAX PRNG key. It is read from the wall clock (nanoseconds since
# the epoch), so the key differs on every run; assign a fixed integer here to
# make a run repeatable.
seed = time_ns()

In [6]:
# Configure generator
# Build the PRNG key from the seed set in the parameters cell.
key = random.key(seed)
# Look up the configuration for the Llama3.2-3B-Instruct checkpoint.
config = ll.checkpoint.load_config("Llama3.2-3B-Instruct")
# ll.chat.generator hands back the chat generator together with a key, which
# replaces `key` (presumably the unconsumed part of a split; TODO confirm
# against llama_jax).
generator, key = ll.chat.generator(config, key)

# Warmup

In [7]:
# Short factual question used for the warm-up request.
prompt = "What is the capital of France?"

In [8]:
# Non-streaming call: send a single user message and take the first event the
# generator produces.
warmup_messages = [{"role": "user", "content": prompt}]
response = next(generator(warmup_messages))

In [9]:
# The reply is the last entry in the returned message list.
last_message = response.messages[-1]
print(last_message.content)

The capital of France is Paris.


# Prompt 0: Counting

In [10]:
# Prompt for the streamed counting example.
prompt = "Count from 0 to 10."

In [11]:
# Stream the reply to stdout as it arrives. An event whose delta is None is
# rendered as a paragraph break (presumably the end-of-stream marker; confirm
# against llama_jax).
stream_request = [{"role": "user", "content": prompt}]
for event in generator(stream_request, stream=True):
    delta = event.delta
    chunk = "\n\n" if delta is None else delta.content
    stdout.write(chunk)

Here we go:

0, 1, 2, 3, 4, 5, 6, 7, 8, 9

