Gemma 4 E4B TT-Lang Inference

Autoregressive inference for Google's Gemma 4 E4B (~4.5B effective parameter) LLM on Tenstorrent hardware, using TT-Lang compute kernels and TTNN for embedding, KV cache, and tensor manipulation.

Architecture

Gemma 4 E4B has 42 layers with a 5:1 sliding-to-global attention pattern, GQA (8 query heads, 2 KV heads), dual RoPE (theta = 10k for sliding layers; theta = 1M with partial rotation for global layers), a gated MLP (GeGLU-style: gelu(gate) * up), QK/V-norm, Per-Layer Embedding (PLE), and KV sharing (the last 18 layers reuse earlier layers' KV caches).
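
The layer counts implied by these numbers can be checked with a small sketch. The layer count, the 5:1 ratio, the head counts, and the 18 sharing layers come from the description above; the position of the global layer inside each group of six and the query-to-KV head mapping are assumptions for illustration and may not match gemma4.py.

```python
NUM_LAYERS = 42          # from the architecture description
SLIDING_PER_GLOBAL = 5   # 5:1 sliding-to-global ratio
NUM_SHARED_KV = 18       # last 18 layers reuse earlier KV caches
NUM_Q_HEADS = 8
NUM_KV_HEADS = 2


def layer_types(num_layers=NUM_LAYERS, sliding_per_global=SLIDING_PER_GLOBAL):
    """Return 'sliding' or 'global' for each layer index.

    ASSUMPTION: the global layer is the last layer of each group of six.
    """
    period = sliding_per_global + 1
    return ["global" if (i + 1) % period == 0 else "sliding"
            for i in range(num_layers)]


def owns_kv_cache(num_layers=NUM_LAYERS, num_shared=NUM_SHARED_KV):
    """True for layers that write their own KV cache, False for the sharing tail."""
    first_shared = num_layers - num_shared
    return [i < first_shared for i in range(num_layers)]


def kv_head_for_query_head(q_head, num_q=NUM_Q_HEADS, num_kv=NUM_KV_HEADS):
    """Map a query head to its KV head.

    ASSUMPTION: contiguous grouping (q_head // group_size), the common GQA
    convention; the flash_attention kernel here may index KV differently.
    """
    group_size = num_q // num_kv
    return q_head // group_size


types = layer_types()
print(types.count("sliding"), types.count("global"))
print(sum(owns_kv_cache()), "layers own a KV cache")
print([kv_head_for_query_head(h) for h in range(NUM_Q_HEADS)])
```

With a 5:1 ratio over 42 layers there are 35 sliding and 7 global layers, and 24 layers own a cache that the remaining 18 reuse.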

Files

  • gemma4.py -- model class (Gemma4Model) with decode_step and generate
  • kernels.py -- all TT-Lang kernels
  • prepare_weights.py -- loads HuggingFace safetensors, transposes for x@w layout, precomputes RoPE tables, saves .pt bundle
  • config.py -- model configs (E2B, E4B, 31B dataclasses)
  • test_kernels.py -- kernel unit tests

TT-Lang Kernels

  • linear_kernel / linear_kernel_hp -- matmul with streaming K-accumulation (hp variant: full K in DFB with f32 accumulate)
  • gelu_mul_kernel -- fused gelu(gate) * up for the gated MLP
  • rotary_kernel -- rotary position embeddings with cos/sin broadcast
  • reshape_to_heads / reshape_from_heads -- head-batched layout transforms for GQA
  • flash_attention -- online flash attention with GQA support (8 cores, one per Q head, KV shared via stride)
  • residual_add_kernel -- elementwise addition
  • scalar_mul_kernel / elementwise_mul_kernel -- scalar and elementwise multiply
  • copy_kernel -- device-side tensor copy
  • softcap_kernel -- logit soft-capping (tanh-based)
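
For checking the elementwise kernels against a host-side reference, a minimal pure-Python sketch of two of them follows. It is not the TT-Lang code: the cap value is left as a parameter because this README does not state it, and the tanh approximation of GELU is an assumption (the kernel may use a different GELU variant).

```python
import math


def softcap(x, cap):
    """Tanh-based soft-capping: squashes x smoothly into (-cap, cap)."""
    return cap * math.tanh(x / cap)


def gelu_tanh(x):
    """GELU, tanh approximation (ASSUMPTION: variant used by the kernel)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def gelu_mul(gate, up):
    """Reference for the fused gelu(gate) * up step, one element at a time."""
    if len(gate) != len(up):
        raise ValueError("gate and up must have the same length")
    return [gelu_tanh(g) * u for g, u in zip(gate, up)]


# Small inputs pass through almost unchanged; large ones saturate near cap.
print(softcap(0.1, 30.0), softcap(1000.0, 30.0))
print(gelu_mul([0.0, 10.0], [5.0, 2.0]))
```

The cap of 30.0 in the usage lines is a hypothetical value chosen only to show the saturation behavior.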

TTNN Ops Used

  • ttnn.embedding -- token embedding lookup
  • ttnn.rms_norm -- RMSNorm (workaround for TT-Lang rsqrt precision bug with small values)
  • ttnn.gelu -- GELU activation (used for the MLP gate and the PLE gate)
  • ttnn.kv_cache.update_cache_for_token_ -- KV cache update (sliding layers)
  • ttnn.slice, ttnn.concat, ttnn.reshape, ttnn.multiply -- tensor manipulation

Usage

# Prepare weights (run on the remote machine; produces a ~16 GB bundle)
python prepare_weights.py --model-dir /path/to/gemma-4-e4b --output /tmp/gemma4_e4b.pt

# Run inference
python -c "
import ttnn
from gemma4 import Gemma4Model, generate
from transformers import AutoTokenizer

device = ttnn.open_device(device_id=0)
model = Gemma4Model('/tmp/gemma4_e4b.pt', device, max_seq=128)
tokenizer = AutoTokenizer.from_pretrained('google/gemma-4-e4b')

tokens = tokenizer.encode('The capital of France is')
output, times = generate(model, tokens, max_tokens=32, temperature=0.7)
print(tokenizer.decode(output))
ttnn.close_device(device)
"

Sample Output

Prompt: 'The capital of France is'
Output:  a city that is full of history and culture. It is a city that is full
         of life and energy. It is a city that is full of people who

Prompt: 'The capital of France is' (temperature=0.7)
Output:  full of wonderful places to explore, but a visit to the Musee du Louvre
         is certainly one of the must-dos. The world-famous art museum is located

Prompt: 'Once upon a time,'
Output:  there was a girl named Sarah who loved to read. She had a huge collection
         of books, but she was always looking for something new and exciting to
         read. One day, she decided to try out a new genre: science fiction.

Speed: ~17 tok/sec (single Tenstorrent p300 chip, decode-only, max_seq=128).

Accuracy

  • Per-layer PCC > 0.993 vs HuggingFace reference (all 42 layers)
  • Final logit PCC = 0.999
  • Top-3 predicted tokens match HuggingFace exactly
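
PCC here is the Pearson correlation coefficient between the device output and the reference output, flattened to one dimension. A minimal sketch of both checks follows; the function names are hypothetical and the repository's own comparison code may differ.

```python
import math


def pcc(a, b):
    """Pearson correlation coefficient of two equal-length sequences."""
    if len(a) != len(b) or len(a) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    if var_a == 0.0 or var_b == 0.0:
        raise ValueError("PCC is undefined for constant input")
    return cov / math.sqrt(var_a * var_b)


def top_k(logits, k=3):
    """Indices of the k largest logits, highest first."""
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]


reference = [0.1, 2.0, -1.0, 1.5]
device = [0.12, 1.97, -1.02, 1.49]   # made-up values for illustration only
print(pcc(reference, device))
print(top_k(reference) == top_k(device))
```

PCC is insensitive to a uniform scale or offset error, which is one reason to pair it with a token-level check such as the top-3 comparison.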

Bugs and Learnings

  • update_cache_for_token_ corrupts adjacent positions when head_dim > 256. The op works correctly for head_dim 128 and 256 (up to 8 tiles) but corrupts neighboring cache rows for head_dim 384+ (12+ tiles). Workaround: split each 512-dim head into two 256-dim halves stored in separate caches, and concatenate them before attention.

  • TT-Lang rsqrt returns wrong values when fused with mul+fill. RMSNorm implemented purely in TT-Lang produces incorrect results for small input magnitudes. Workaround: use ttnn.rms_norm instead.

  • Global RoPE frequency denominator must use full head_dim. For partial rotation (e.g., 128 of 512 dims), the frequency formula 1/theta^(i/dim) must use dim=512 (full head), not dim=128 (rotary portion). The non-rotated dims get cos=1, sin=0 padding.

  • Gemma 4 attention scaling is 1.0, not 1/sqrt(head_dim). QK-norm handles the magnitude control.
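
The global RoPE point above can be sketched as follows. The dimensions (128 rotated of 512) and theta = 1M come from this README; the table layout (one cos/sin entry per pair of dimensions, rotated pairs first) and the function name are assumptions and may not match what prepare_weights.py stores.

```python
import math


def global_rope_tables(position, head_dim=512, rotary_dim=128, theta=1_000_000.0):
    """cos/sin for one position with partial rotation.

    The exponent uses i / head_dim (full head), not i / rotary_dim.
    ASSUMPTION: i runs over even indices 0, 2, ..., rotary_dim - 2, and the
    tables hold one entry per pair of dimensions.
    """
    inv_freq = [1.0 / theta ** (i / head_dim) for i in range(0, rotary_dim, 2)]
    num_unrotated_pairs = (head_dim - rotary_dim) // 2
    # Identity rotation (cos=1, sin=0) for the dims that are not rotated.
    cos = [math.cos(position * f) for f in inv_freq] + [1.0] * num_unrotated_pairs
    sin = [math.sin(position * f) for f in inv_freq] + [0.0] * num_unrotated_pairs
    return cos, sin


cos, sin = global_rope_tables(position=5)
print(len(cos), len(sin))

# The wrong denominator (rotary_dim) gives a different second frequency.
right = 1.0 / 1_000_000.0 ** (2 / 512)
wrong = 1.0 / 1_000_000.0 ** (2 / 128)
print(right, wrong)
```

Using rotary_dim in the denominator makes the frequencies decay four times faster across the rotated pairs, so every pair after the first is rotated by the wrong angle.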
