### LLM Inference Example

This notebook walks through a basic inference example that uses our `ttml` Python API to build, load, and run a large language model from Hugging Face on TT hardware. As configured below, it loads TinyLlama (`TinyLlama/TinyLlama-1.1B-Chat-v1.0`), but it can be edited to use any of the LLMs that the tt-train project currently supports.

Below, in the first cell, we have our imports and basic directory housekeeping.

In [1]:
import os, sys, random
import numpy as np  # For numpy arrays
from dataclasses import dataclass # For configuration classes
from huggingface_hub import hf_hub_download # To download safetensors from Hugging Face
from transformers import AutoTokenizer
from yaml import safe_load # To read YAML configs
from pathlib import Path

sys.path.append(f"{os.getenv('TT_METAL_RUNTIME_ROOT')}/tt-train/sources/ttml")
import ttml
from ttml.common.config import get_training_config, load_config
from ttml.common.utils import set_seed, round_up_to_tile
from ttml.common.model_factory import TransformerModelFactory

# Change working directory to tt-train
os.chdir(f"{os.environ['TT_METAL_RUNTIME_ROOT']}/tt-train")


Use the cell below to change global parameters in this notebook. 

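`OUTPUT_TOKENS` : the length of the generated text in tokens (not characters!)

`TEMPERATURE`   : sampling temperature; set to 0 to disable sampling in `generate_with_tt()`

`CONFIG`        : training config YAML; its `model_config` entry selects the transformer config

`model_path`    : Hugging Face repo ID used for both the tokenizer and the SafeTensors weights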

`SEED`          : randomization seed (for reproducibility)
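
To make the role of `TEMPERATURE` concrete, here is a minimal NumPy sketch of temperature sampling. It is an illustration only and does not claim to reproduce how `ttml.ops.sample.sample_op` works internally. The helper name `sample_with_temperature` is hypothetical, and treating temperature 0 as greedy (argmax) decoding is an assumption made for this sketch.

```python
import numpy as np

def sample_with_temperature(logits: np.ndarray, temperature: float,
                            rng: np.random.Generator) -> int:
    """Hypothetical helper: pick one token ID from a 1-D array of logits."""
    if temperature <= 0.0:
        # Assumed convention for this sketch: temperature 0 means greedy decoding.
        return int(np.argmax(logits))
    scaled = logits / temperature
    scaled = scaled - scaled.max()  # subtract the max for numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))

rng = np.random.default_rng(42)
logits = np.array([2.0, 1.0, 0.1], dtype=np.float64)

greedy_token = sample_with_temperature(logits, 0.0, rng)   # always the argmax
sampled_token = sample_with_temperature(logits, 0.8, rng)  # random draw
```

Lower temperatures concentrate probability on the highest logit, while higher temperatures flatten the distribution and make less likely tokens appear more often.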

In [2]:
OUTPUT_TOKENS = 750
WITH_SAMPLING = True
TEMPERATURE = 0.8
SEED = 42
CONFIG = "training_shakespeare_llama3_gpt2s_size.yaml"

set_seed(SEED)

model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

The notebook is currently configured for TinyLlama. To use a different tokenizer, change `model_path`, which is the argument passed to `from_pretrained()` below.

In [3]:
# Load the tokenizer from Hugging Face and the transformer config from YAML
tokenizer = AutoTokenizer.from_pretrained(model_path)
training_config = get_training_config(CONFIG)
model_yaml = load_config(training_config.model_config, configs_root=os.getcwd())

As above, the call to `hf_hub_download()` downloads the SafeTensors weight file for the model named in `model_path` (or finds it in your local cache). Change the repo ID or filename to load other SafeTensors files.

In [None]:
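# Get safetensors: download the weight file (or reuse the cached copy),
# then keep only its parent directory, including the trailing slash
safetensors_path = hf_hub_download(repo_id=model_path, filename="model.safetensors")
safetensors_path = safetensors_path.replace("model.safetensors","")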

print(f"Safetensors path: {safetensors_path}")

orig_vocab_size = tokenizer.vocab_size

tt_model_factory = TransformerModelFactory(model_yaml)
tt_model_factory.transformer_config.vocab_size = orig_vocab_size

max_sequence_length = tt_model_factory.transformer_config.max_sequence_length

tt_model = tt_model_factory.create_model()
tt_model.load_from_safetensors(safetensors_path)
tt_model

padded_vocab_size = round_up_to_tile(orig_vocab_size, 32)

if orig_vocab_size != padded_vocab_size:
    print(f"Padding vocab size for tilization: original {orig_vocab_size} -> padded {padded_vocab_size}")
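
As a worked example of the padding check above, the sketch below rounds two vocabulary sizes up to a multiple of 32. The function `round_up` is a hypothetical stand-in written for this example; it assumes that the imported `round_up_to_tile(value, 32)` rounds up to the next multiple of its second argument.

```python
def round_up(value: int, multiple: int) -> int:
    """Hypothetical stand-in: round value up to the nearest multiple."""
    return ((value + multiple - 1) // multiple) * multiple

TILE = 32
llama_vocab = 32000  # already a multiple of 32, so no padding is needed
gpt2_vocab = 50257   # not a multiple of 32, so it gets padded

padded_llama = round_up(llama_vocab, TILE)
padded_gpt2 = round_up(gpt2_vocab, TILE)
print(padded_llama, padded_gpt2)  # 32000 50272
```

When the padded size differs from the original, the extra vocabulary entries do not correspond to real tokens, which is why the generation functions later build a logits mask for them.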


---

## Part 1: Model Setup WITHOUT KV Cache

This section sets up a model that does NOT use KV cache. Each generation step performs a full forward pass through the entire sequence, which is slower but simpler.


The cell below defines two mask helpers and `generate_with_tt()`, which uses TT hardware acceleration to generate output from the chosen LLM.

In [None]:
def build_causal_mask(T: int) -> ttml.autograd.Tensor:
    # [1,1,T,T] mask with 1s for allowed positions (i >= j), else 0; built as float32, stored as bfloat16
    m = np.tril(np.ones((T, T), dtype=np.float32))
    return ttml.autograd.Tensor.from_numpy(m.reshape(1, 1, T, T), ttml.Layout.TILE, ttml.autograd.DataType.BFLOAT16)

def build_logits_mask(vocab_size: int, padded_vocab_size: int) -> ttml.autograd.Tensor:
    logits_mask = np.zeros((1, 1, 1, padded_vocab_size), dtype=np.float32)
    logits_mask[:, :, :, vocab_size:] = 1e4
    return ttml.autograd.Tensor.from_numpy(logits_mask, ttml.Layout.TILE, ttml.autograd.DataType.BFLOAT16)   # [1,1,1,padded_vocab_size], bfloat16

def generate_with_tt(model, prompt_tokens):
    """Generate text without KV cache (full sequence forward pass each step)."""
    import time
    
    ttml.autograd.AutoContext.get_instance().set_gradient_mode(ttml.autograd.GradMode.DISABLED)
    model.eval()

    logits_mask_tensor = None

    if padded_vocab_size != orig_vocab_size:
        logits_mask_tensor = build_logits_mask(orig_vocab_size, padded_vocab_size)

    causal_mask = build_causal_mask(max_sequence_length)  # [1,1,max_seq_len,max_seq_len], bfloat16
    padded_prompt_tokens = np.zeros((1, 1, 1, max_sequence_length), 
                                    dtype=np.uint32)

    start_idx = 0
    prompt_len = len(prompt_tokens)

    print("=" * 80)
    print("Running Inference WITHOUT KV Cache (Full Forward Pass Each Step)")
    print(f"Prompt tokens: {prompt_tokens[:10]}{'...' if len(prompt_tokens) > 10 else ''}")
    print(f"Prompt length: {prompt_len}")
    print(f"Max new tokens: {OUTPUT_TOKENS}")
    print("=" * 80)
    print("\nGenerated text:")
    print("************************************")
    
    start_time = time.time()

    generated_tokens = prompt_tokens.copy()
    
    for token_idx in range(OUTPUT_TOKENS):

        # Once the sequence exceeds the context window, keep only the most recent tokens
        if len(prompt_tokens) > max_sequence_length:
            start_idx = len(prompt_tokens) - max_sequence_length

        window = prompt_tokens[start_idx:]
        padded_prompt_tokens[0, 0, 0, :len(window)] = window
        padded_prompt_tensor = ttml.autograd.Tensor.from_numpy(
            padded_prompt_tokens,
            ttml.Layout.ROW_MAJOR,
            ttml.autograd.DataType.UINT32)  # [1,1,1, max_seq_len], uint32

        logits = model(padded_prompt_tensor, causal_mask)  # out=[1,1,seq_len, vocab_size], bf16

        # Draw a fresh integer seed for each sampling call
        sample_seed = np.random.randint(0, 10_000_000)
        next_token_tensor = ttml.ops.sample.sample_op(logits, TEMPERATURE, sample_seed, logits_mask_tensor)  # out=[1,1,seq_len,1], uint32

        next_token_idx = max_sequence_length - 1 if len(prompt_tokens) > max_sequence_length else len(prompt_tokens) - 1
        next_token = int(next_token_tensor.to_numpy().flatten()[next_token_idx])
        generated_tokens.append(next_token)

        output = tokenizer.decode([next_token], skip_special_tokens=False)

        prompt_tokens.append(next_token)
        print(output, end='', flush=True)

    end_time = time.time()
    total_duration_ms = (end_time - start_time) * 1000
    new_tokens = len(prompt_tokens) - prompt_len
    
    print("\n************************************")
    print("\n=== GENERATION SUMMARY ===")
    print(f"Total tokens (prompt + new): {len(prompt_tokens)}")
    print(f"  Prompt: {prompt_len} tokens")
    print(f"  New: {new_tokens} tokens")
    print(f"\nTotal time: {total_duration_ms:.2f} ms")
    print(f"Average time per token: {total_duration_ms / new_tokens if new_tokens > 0 else 0:.2f} ms")
    print("=" * 80)
    print("\n")
    print("Final result:")
    print(tokenizer.decode(generated_tokens, skip_special_tokens=False))
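
The two mask helpers defined in the cell above can be illustrated with plain NumPy on toy sizes. This is a sketch for intuition only: the sizes are made up, and the way the logits mask is applied (subtracted from the logits before a token is chosen) is an assumption, since the notebook does not show what `sample_op` does with it.

```python
import numpy as np

# Causal mask for a 4-token sequence: row i has ones in columns 0..i,
# so position i may attend to itself and to earlier positions only.
T = 4
causal = np.tril(np.ones((T, T), dtype=np.float32))
print(causal)

# Logits mask for a toy vocabulary of 5 real tokens padded to 8 entries.
vocab_size, padded_vocab_size = 5, 8
logits_mask = np.zeros(padded_vocab_size, dtype=np.float32)
logits_mask[vocab_size:] = 1e4

# Assumption for illustration: the mask is subtracted from the logits,
# which pushes the padded entries far below any real token.
logits = np.zeros(padded_vocab_size, dtype=np.float32)
masked = logits - logits_mask
best = int(np.argmax(masked))
```

With the padded entries pushed down by 1e4, the selected ID always falls inside the real vocabulary range.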

---

## Part 2: Model Setup WITH KV Cache

This section sets up a model that uses KV cache for efficient inference. The KV cache stores previously computed key-value pairs from the attention mechanism, allowing each generation step to only process the newly generated token instead of recomputing the entire sequence. This significantly speeds up generation, especially for longer sequences.

**Key differences from non-cache generation:**

- **Prefill phase**: First step processes the entire prompt and stores KV pairs in cache
- **Decode phase**: Subsequent steps only process the last generated token, reusing cached KV pairs
- **Performance**: Much faster than non-cache generation, with speedup increasing as sequence length grows
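
A rough way to see where the speedup comes from is to count how many token positions each approach feeds through the model. The sketch below uses the values from the run shown later in this notebook (prompt length 5, 750 new tokens, maximum sequence length 1024). The function names are hypothetical, and the count deliberately ignores tile padding and the cost of attending to cached keys, so it is not a timing estimate.

```python
def positions_without_cache(new_tokens: int, max_seq_len: int) -> int:
    """Every step runs the full padded buffer of max_seq_len positions."""
    return new_tokens * max_seq_len

def positions_with_cache(prompt_len: int, new_tokens: int) -> int:
    """One prefill over the prompt, then one position per remaining step."""
    if new_tokens == 0:
        return 0
    return prompt_len + (new_tokens - 1)

without_cache = positions_without_cache(new_tokens=750, max_seq_len=1024)
with_cache = positions_with_cache(prompt_len=5, new_tokens=750)
print(without_cache, with_cache)  # 768000 754
```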


In [None]:
TILE_SIZE = 32

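def round_up_to_tile(value: int, tile_size: int = TILE_SIZE) -> int:
    """Round up to nearest multiple of tile_size (defaults to TILE_SIZE).

    Note: this shadows the helper imported from ttml.common.utils. The optional
    second argument keeps the earlier two-argument call working if cells are re-run.
    """
    return ((value + tile_size - 1) // tile_size) * tile_size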

def create_causal_mask_kv_cache(query_seq_len: int, prompt_len: int = 0) -> ttml.autograd.Tensor:
    """Create a causal attention mask for autoregressive generation with KV cache.
    
    This matches the C++ implementation exactly.
    
    Args:
        query_seq_len: Length of query sequence
        prompt_len: Length of prompt (for decode mode, this is the cache position)
    
    Returns:
        Causal mask tensor
    """
    whole_seq_len = prompt_len + query_seq_len
    padded_query_len = round_up_to_tile(query_seq_len)
    padded_whole_len = round_up_to_tile(whole_seq_len)
    
    # Mask shape: [padded_query_len, padded_whole_len] - query_len x key_len
    mask_data = np.zeros((padded_query_len, padded_whole_len), dtype=np.float32)
    
    # Fill mask: token i can attend to positions 0 through i + prompt_len (inclusive)
    # This matches C++: for (uint32_t j = 0; j <= i + prompt_len; ++j)
    # range(n) gives [0, 1, 2, ..., n-1], so range(prompt_len + i + 1) gives [0, 1, 2, ..., prompt_len + i]
    for i in range(query_seq_len):
        for j in range(prompt_len + i + 1):
            mask_data[i, j] = 1.0
    
    # Reshape to [1, 1, padded_query_len, padded_whole_len]
    mask_data = mask_data.reshape(1, 1, padded_query_len, padded_whole_len)
    mask_tensor = ttml.autograd.Tensor.from_numpy(
        mask_data,
        layout=ttml.Layout.TILE,
        new_type=ttml.autograd.DataType.BFLOAT16
    )
    
    return mask_tensor


def tokens_to_tensor_kv_cache(tokens: list) -> ttml.autograd.Tensor:
    """Create tensor from token IDs with padding to nearest multiple of 32.
    
    This matches the C++ implementation exactly.
    
    Args:
        tokens: List of token IDs
    
    Returns:
        Token tensor with padding
    """
    actual_len = len(tokens)
    padded_len = round_up_to_tile(actual_len)
    
    # Pad tokens with zeros to reach padded length
    padded_tokens = np.zeros(padded_len, dtype=np.uint32)
    for i in range(actual_len):
        padded_tokens[i] = tokens[i]
    
    # Reshape to [1, 1, 1, padded_len]
    padded_tokens = padded_tokens.reshape(1, 1, 1, padded_len)
    tokens_tensor = ttml.autograd.Tensor.from_numpy(
        padded_tokens,
        layout=ttml.Layout.ROW_MAJOR,
        new_type=ttml.autograd.DataType.UINT32
    )
    
    return tokens_tensor


def sample_token_from_logits(logits: ttml.autograd.Tensor, position: int) -> int:
    """Sample next token using greedy decoding (argmax).
    
    Args:
        logits: Logits tensor
        position: 1-based position to sample from (the number of real, unpadded input tokens)
    
    Returns:
        Token ID with highest logit
    """
    shape = logits.shape()
    vocab_size = shape[-1]
    
    # View the logits as [num_positions, vocab_size] and pick the requested row
    logits_host = logits.to_numpy().reshape(-1, vocab_size)
    
    # Find token with highest logit value
    return int(np.argmax(logits_host[position - 1]))


def generate_with_tt_kv_cache(model, prompt_tokens, transformer_config, use_sampling=True):
    """Generate text with KV cache for efficient inference.
    
    Args:
        model: LLaMA model instance
        prompt_tokens: Initial prompt token IDs
        transformer_config: Model config with num_blocks, num_groups, embedding_dim, max_sequence_length
        use_sampling: Whether to use temperature sampling (if False, uses greedy decoding)
    """
    import time
    
    ttml.autograd.AutoContext.get_instance().set_gradient_mode(ttml.autograd.GradMode.DISABLED)
    model.eval()
    
    # Create KV cache
    batch_size = 1
    num_layers = transformer_config.num_blocks
    num_groups = transformer_config.num_groups
    max_seq_len = transformer_config.max_sequence_length
    head_dim = transformer_config.embedding_dim // transformer_config.num_heads
    
    kv_cache_config = ttml.models.KvCacheConfig(
        num_layers, batch_size, num_groups, max_seq_len, head_dim
    )
    kv_cache = ttml.models.KvCache(kv_cache_config)
    
    # Reset KV cache for new sequence
    kv_cache.reset()
    
    generated_tokens = prompt_tokens.copy()
    prompt_len = len(prompt_tokens)
    
    logits_mask_tensor = None
    if padded_vocab_size != orig_vocab_size:
        logits_mask_tensor = build_logits_mask(orig_vocab_size, padded_vocab_size)
    
    print("=" * 80)
    print("Running Inference with KV Cache")
    print(f"Prompt tokens: {prompt_tokens[:10]}{'...' if len(prompt_tokens) > 10 else ''}")
    print(f"Prompt length: {prompt_len}")
    print(f"Max new tokens: {OUTPUT_TOKENS}")
    print("=" * 80)
    print("\nGenerated text:")
    print("************************************")
    
    start_time = time.time()
    
    # Generate tokens one by one
    for step in range(min(OUTPUT_TOKENS, max_seq_len - prompt_len)):
        # For first step (prefill): use all prompt tokens
        # For subsequent steps (decode): use only the last generated token  
        processed_tokens = 0
        if kv_cache.get_cache_position() == 0:
            # Prefill: process entire prompt
            input_tokens = generated_tokens
        else:
            # Decode: process only last token
            input_tokens = [generated_tokens[-1]]
            processed_tokens = len(generated_tokens)-1
        
        token_tensor = tokens_to_tensor_kv_cache(input_tokens)
        
        # Create causal mask
        # For prefill: query_len = prompt_len, prompt_len = 0 (all tokens can attend to previous)
        # For decode: query_len = 1, prompt_len = cache_position (new token can attend to all cached tokens)
        # This matches C++: create_causal_mask(device, input_tokens.size(), processed_tokens)
        mask = create_causal_mask_kv_cache(len(input_tokens), processed_tokens)
        new_tokens = len(input_tokens)
        logits = model(token_tensor, mask, kv_cache=kv_cache, new_tokens=new_tokens)
        
        # Sample next token
        # The logits tensor has shape [1, 1, seq_len, vocab_size] where seq_len may be padded
        # We need to extract the token at the last actual position (len(input_tokens) - 1)
        if use_sampling:
            next_token_tensor = ttml.ops.sample.sample_op(
                logits, TEMPERATURE, np.random.randint(low=1e7), logits_mask_tensor
            )
            next_token_idx = len(input_tokens) - 1
            next_token = int(next_token_tensor.to_numpy().flatten()[next_token_idx])
        else:
            # Greedy decoding - extract logits at last position and find argmax
            next_token = int(sample_token_from_logits(logits, len(input_tokens)))
        
        output = tokenizer.decode([next_token], skip_special_tokens=False)
        generated_tokens.append(next_token)
        print(output, end='', flush=True)
    
    end_time = time.time()
    total_duration_ms = (end_time - start_time) * 1000
    new_tokens = len(generated_tokens) - prompt_len

    kv_cache.reset()
    
    print("\n************************************")
    print("\n=== GENERATION SUMMARY ===")
    print(f"Total tokens (prompt + new): {len(generated_tokens)}")
    print(f"  Prompt: {prompt_len} tokens")
    print(f"  New: {new_tokens} tokens")
    print(f"\nTotal time: {total_duration_ms:.2f} ms")
    print(f"Average time per token: {total_duration_ms / new_tokens if new_tokens > 0 else 0:.2f} ms")
    print("=" * 80)
    print("\n")
    print("Final result:")
    print(tokenizer.decode(generated_tokens, skip_special_tokens=False))
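
To see what `create_causal_mask_kv_cache()` produces in the two phases, here is a NumPy-only sketch of the same mask layout for a 5-token prompt. The names `pad_to_tile` and `kv_mask` are hypothetical and exist only for this example; the sketch returns plain arrays rather than `ttml` tensors.

```python
import numpy as np

TILE = 32

def pad_to_tile(n: int) -> int:
    """Round n up to the nearest multiple of TILE."""
    return ((n + TILE - 1) // TILE) * TILE

def kv_mask(query_len: int, cached_len: int) -> np.ndarray:
    """Query row i may attend to key columns 0..i + cached_len (inclusive)."""
    mask = np.zeros((pad_to_tile(query_len), pad_to_tile(cached_len + query_len)),
                    dtype=np.float32)
    rows = np.arange(query_len)[:, None]
    cols = np.arange(mask.shape[1])[None, :]
    mask[:query_len] = (cols <= rows + cached_len).astype(np.float32)
    return mask

# Prefill: all 5 prompt tokens are queries, nothing is cached yet.
prefill = kv_mask(query_len=5, cached_len=0)
# Decode: a single new token attends to the 5 cached tokens and to itself.
decode = kv_mask(query_len=1, cached_len=5)

print(prefill.shape, prefill[:5, :6])
print(decode.shape, decode[0, :8])
```

In the prefill mask the first five rows form a lower triangle, while the decode mask has a single active row with ones in columns 0 through 5. All padded rows and columns stay at zero.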


---

## Part 3: Generation Examples

### 3.1: Generation WITHOUT KV Cache

Examples using the non-cache model. These will be slower but demonstrate the baseline approach.


In [None]:
prompt_str = "The difference between cats and dogs is:"
prompt_tokens = tokenizer.encode(prompt_str)
print("Generating with TT (WITHOUT KV Cache):")
generate_with_tt(tt_model, prompt_tokens.copy())


### 3.2: Generation WITH KV Cache 

Now let's generate text using the KV cache-enabled model. This will be much faster than the regular generation, especially for longer sequences.


In [None]:
prompt_str = "Who are you?"
prompt_tokens = tokenizer.encode(prompt_str)
print("Generating with TT (KV Cache):")
generate_with_tt_kv_cache(tt_model, prompt_tokens.copy(), tt_model_factory.transformer_config, use_sampling=WITH_SAMPLING)


Generating with TT (KV Cache):
Running Inference with KV Cache
Prompt tokens: [1, 11644, 526, 366, 29973]
Prompt length: 5
Max new tokens: 750

Generated text:
************************************

Initializing KV cache:
    Batch size: 1
    Num layers: 22
    Num groups: 4
    Max sequence length: 1024
    Head dim: 64
KV cache initialized successfully

Char

les:Iamagenie.

Narrator:Agenie?Whatdoesthatevenmean?

Charles:Agenieisacreatureofpureanduntamedenergy,createdbymagic.

Narrator:Andwhatmagiccanyouperform?

Charles:Icangrantwishes.

Narrator:Buthow?

Charles:Bygrantingwishes,youcanchangeyourlifeinwaysyoucouldneverhaveimaginedbefore.

Narrator:That'samazing!Canyougrantoneforme?

Charles:Icangrantyouonewish,butremember,it'sjustonewish.

Narrator:Okay,I'llthinkaboutit.

Scene3:

Lucasissittinginasmall,dimlylitroomwithnowindows.He'swearingalong-sleevedshirtandjeans.He'sfeelingnervousandunsure.

Narrator:Welcometothesecretsociety,Lucas.

Lucas:(sighs)Idon'tknowwhattoexpect.

Narrator:You'rejoiningoneofthemostprestigiousandsecretiveorganizationsintheworld.

Lucas:(confidently)I'mreadyforanything.

Narrator:Butfirst,weneedtogothroughaninitiationprocess.

Scene4:

Lucasisstandinginfrontofalarge,ornatechest.He'sholdingasmallkeyinhishand.

Narrator:Yourfirstorderofbusinessistoopenthischestandrevealitscontents.

Lucas:(sighs)Okay,butwhatifIdon't
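
The streamed text above has no spaces between words. A likely cause, stated here as an assumption rather than something verified against this tokenizer, is that decoding one SentencePiece token at a time drops the leading-space marker of each piece. The sketch below reproduces the effect with a toy vocabulary and shows one common workaround: decode the growing sequence and print only the new suffix. `PIECES` and `toy_decode` are hypothetical stand-ins for the real tokenizer.

```python
# Toy SentencePiece-style vocabulary; "▁" marks a leading space.
PIECES = {0: "▁I", 1: "▁am", 2: "▁a", 3: "▁gen", 4: "ie", 5: "."}

def toy_decode(ids):
    """Hypothetical stand-in for tokenizer.decode(): joins the pieces and
    drops the single leading space of the result."""
    text = "".join(PIECES[i] for i in ids).replace("▁", " ")
    return text[1:] if text.startswith(" ") else text

ids = [0, 1, 2, 3, 4, 5]

# Decoding one token at a time loses the spaces between words.
per_token = "".join(toy_decode([i]) for i in ids)

# Decoding the growing sequence and emitting only the new suffix keeps them.
streamed, printed = "", 0
for n in range(1, len(ids) + 1):
    text = toy_decode(ids[:n])
    streamed += text[printed:]
    printed = len(text)

print(per_token)  # Iamagenie.
print(streamed)   # I am a genie.
```

Applied to the generation loops in this notebook, the same idea would mean decoding `generated_tokens` at each step and printing only the characters added since the previous step, at the cost of re-decoding the sequence every time.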