Autoregressive inference for Google's Gemma 4 E4B LLM (~4.5B effective parameters) on Tenstorrent hardware, using TT-Lang compute kernels and TTNN for embedding, KV cache, and tensor manipulation.
Gemma 4 E4B has 42 layers with a 5:1 sliding-to-global attention pattern, GQA (8 query heads, 2 KV heads), dual RoPE (10k theta sliding, 1M theta global with partial rotation), SwiGLU MLP, QK/V-norm, Per-Layer Embedding (PLE), and KV sharing (last 18 layers reuse earlier KV caches).
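
The layer layout described above can be sketched in a few lines of plain Python. This is only an illustration: the constant names are hypothetical, and the placement of the global layer at the end of each group of six, as well as the mapping of consecutive query heads onto one KV head, are assumptions rather than something read from config.py.

```python
# Illustrative layer schedule; names here are hypothetical, not from config.py.
NUM_LAYERS = 42
SLIDING_PER_GLOBAL = 5      # 5:1 sliding-to-global pattern
NUM_KV_SHARED = 18          # last 18 layers reuse earlier KV caches
N_Q_HEADS, N_KV_HEADS = 8, 2


def layer_type(i: int) -> str:
    # Assumption: each group of 6 layers ends with the global-attention layer.
    return "global" if (i + 1) % (SLIDING_PER_GLOBAL + 1) == 0 else "sliding"


layer_types = [layer_type(i) for i in range(NUM_LAYERS)]

# Layers at or after this index do not own a KV cache of their own.
first_shared_layer = NUM_LAYERS - NUM_KV_SHARED

# Assumption: consecutive query heads share one KV head (GQA grouping).
kv_head_for_q = [h // (N_Q_HEADS // N_KV_HEADS) for h in range(N_Q_HEADS)]

print(layer_types.count("sliding"), layer_types.count("global"))
print(first_shared_layer, kv_head_for_q)
```

Which earlier cache a shared layer reads from is not stated here, so the sketch stops at identifying the layers that share.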

Files:

- gemma4.py -- model class (Gemma4Model) with decode_step and generate
- kernels.py -- all TT-Lang kernels
- prepare_weights.py -- loads HuggingFace safetensors, transposes for x@w layout, precomputes RoPE tables, saves .pt bundle
- config.py -- model configs (E2B, E4B, 31B dataclasses)
- test_kernels.py -- kernel unit tests

TT-Lang kernels:

- linear_kernel / linear_kernel_hp -- matmul with streaming K-accumulation (hp variant: full K in DFB with f32 accumulate)
- gelu_mul_kernel -- fused gelu(gate) * up for SwiGLU MLP
- rotary_kernel -- rotary position embeddings with cos/sin broadcast
- reshape_to_heads / reshape_from_heads -- head-batched layout transforms for GQA
- flash_attention -- online flash attention with GQA support (8 cores, one per Q head, KV shared via stride)
- residual_add_kernel -- elementwise addition
- scalar_mul_kernel / elementwise_mul_kernel -- scalar and elementwise multiply
- copy_kernel -- device-side tensor copy
- softcap_kernel -- logit soft-capping (tanh-based)
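
For readers unfamiliar with online flash attention, the following NumPy sketch shows the running-max and running-sum recurrence for a decode step with GQA, checked against a naive softmax. It is a host-side reference under stated assumptions, not the TT-Lang kernel: the function names, the block size, and the `h // group` mapping from query head to KV head are all illustrative.

```python
import numpy as np


def naive_attention(q, k, v, scale=1.0):
    # q: [n_q_heads, q_len, d]; k, v: [n_kv_heads, kv_len, d]
    group = q.shape[0] // k.shape[0]
    out = np.empty_like(q)
    for h in range(q.shape[0]):
        kh, vh = k[h // group], v[h // group]  # assumed GQA head mapping
        s = scale * (q[h] @ kh.T)
        p = np.exp(s - s.max(axis=-1, keepdims=True))
        out[h] = (p / p.sum(axis=-1, keepdims=True)) @ vh
    return out


def online_attention(q, k, v, block=32, scale=1.0):
    # Same result as naive_attention, but K/V are consumed block by block.
    n_q, q_len, d = q.shape
    group = n_q // k.shape[0]
    out = np.empty_like(q)
    for h in range(n_q):
        kh, vh = k[h // group], v[h // group]
        m = np.full((q_len, 1), -np.inf)   # running row maximum
        l = np.zeros((q_len, 1))           # running softmax denominator
        acc = np.zeros((q_len, d))         # running unnormalized output
        for start in range(0, kh.shape[0], block):
            s = scale * (q[h] @ kh[start:start + block].T)
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            alpha = np.exp(m - m_new)      # rescales earlier partial results
            p = np.exp(s - m_new)
            l = alpha * l + p.sum(axis=-1, keepdims=True)
            acc = alpha * acc + p @ vh[start:start + block]
            m = m_new
        out[h] = acc / l
    return out


rng = np.random.default_rng(0)
q = rng.standard_normal((8, 1, 64))     # 8 query heads, one decode token
k = rng.standard_normal((2, 100, 64))   # 2 KV heads, 100 cached positions
v = rng.standard_normal((2, 100, 64))
ref = naive_attention(q, k, v)
out = online_attention(q, k, v, block=32)
print("max abs diff:", np.abs(ref - out).max())
```

No causal mask is applied because the sketch models a single decode token attending to every cached position; sliding-window masking is left out.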

TTNN ops:

- ttnn.embedding -- token embedding lookup
- ttnn.rms_norm -- RMSNorm (workaround for TT-Lang rsqrt precision bug with small values)
- ttnn.gelu -- GELU activation (used for SwiGLU gate and PLE gate)
- ttnn.kv_cache.update_cache_for_token_ -- KV cache update (sliding layers)
- ttnn.slice, ttnn.concat, ttnn.reshape, ttnn.multiply -- tensor manipulation
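
When checking a device RMSNorm against the host, a float32 NumPy reference is handy, especially for the small-magnitude inputs mentioned above. The sketch below is a generic RMSNorm and makes assumptions: the `eps` value and the plain `weight` scaling are illustrative and may not match the exact Gemma 4 convention or the `ttnn.rms_norm` defaults.

```python
import numpy as np


def rms_norm_ref(x, weight, eps=1e-6):
    # Generic RMSNorm computed in float32; eps and the scaling form are assumptions.
    x32 = np.asarray(x, dtype=np.float32)
    mean_sq = np.mean(x32 * x32, axis=-1, keepdims=True)
    return x32 / np.sqrt(mean_sq + eps) * weight


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 256)).astype(np.float32)
w = np.ones(256, dtype=np.float32)

# Unit-scale input: the output RMS stays close to 1.
normal_rms = np.sqrt(np.mean(rms_norm_ref(x, w) ** 2, axis=-1))

# Small-magnitude input: mean(x^2) is comparable to eps, so rsqrt accuracy
# in this range directly affects the result.
small_rms = np.sqrt(np.mean(rms_norm_ref(x * 1e-3, w) ** 2, axis=-1))

print(normal_rms, small_rms)
```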
# Prepare weights (on remote, ~16GB bundle)
python prepare_weights.py --model-dir /path/to/gemma-4-e4b --output /tmp/gemma4_e4b.pt
# Run inference
python -c "
import ttnn
from gemma4 import Gemma4Model, generate
from transformers import AutoTokenizer
device = ttnn.open_device(device_id=0)
model = Gemma4Model('/tmp/gemma4_e4b.pt', device, max_seq=128)
tokenizer = AutoTokenizer.from_pretrained('google/gemma-4-e4b')
tokens = tokenizer.encode('The capital of France is')
output, times = generate(model, tokens, max_tokens=32, temperature=0.7)
print(tokenizer.decode(output))
ttnn.close_device(device)
"

Prompt: 'The capital of France is'
Output: a city that is full of history and culture. It is a city that is full
of life and energy. It is a city that is full of people who
Prompt: 'The capital of France is' (temperature=0.7)
Output: full of wonderful places to explore, but a visit to the Musee du Louvre
is certainly one of the must-dos. The world-famous art museum is located
Prompt: 'Once upon a time,'
Output: there was a girl named Sarah who loved to read. She had a huge collection
of books, but she was always looking for something new and exciting to
read. One day, she decided to try out a new genre: science fiction.
Speed: ~17 tok/sec (single Tenstorrent p300 chip, decode-only, max_seq=128).
- Per-layer PCC > 0.993 vs HuggingFace reference (all 42 layers)
- Final logit PCC = 0.999
- Top-3 predicted tokens match HuggingFace exactly
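
The PCC figures above can be reproduced with a small helper, assuming PCC here means the Pearson correlation coefficient over flattened tensors. The function name `pcc` is illustrative and is not taken from the repository.

```python
import numpy as np


def pcc(a, b):
    # Pearson correlation coefficient between two tensors, flattened to 1-D.
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    a = a - a.mean()
    b = b - b.mean()
    return float((a * b).sum() / np.sqrt((a * a).sum() * (b * b).sum()))


rng = np.random.default_rng(0)
reference = rng.standard_normal((1, 2560))
# Simulated device output: the reference plus a small amount of noise.
device_out = reference + 0.01 * rng.standard_normal(reference.shape)
print(pcc(reference, device_out))
```

Because PCC is invariant to scale and offset, it is worth pairing it with a top-k token comparison, as the list above does.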

- update_cache_for_token_ corrupts adjacent positions when head_dim > 256. The op works correctly for head_dim 128 and 256 (up to 8 tiles) but corrupts neighboring cache rows for head_dim 384+ (12+ tiles). Workaround: split 512-dim heads into two 256-dim halves stored in separate caches, concat before attention.
- TT-Lang rsqrt returns wrong values when fused with mul+fill. RMSNorm implemented purely in TT-Lang produces incorrect results for small input magnitudes. Workaround: use ttnn.rms_norm instead.
- Global RoPE frequency denominator must use full head_dim. For partial rotation (e.g., 128 of 512 dims), the frequency formula 1/theta^(i/dim) must use dim=512 (full head), not dim=128 (rotary portion). The non-rotated dims get cos=1, sin=0 padding.
- Gemma 4 attention scaling is 1.0, not 1/sqrt(head_dim). QK-norm handles the magnitude control.
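
The partial-rotation point is easy to get wrong, so here is a minimal NumPy sketch of building the global-layer cos/sin tables with the full head_dim in the denominator. The function name, the half-width table shape, and the choice to place the rotated pairs first are assumptions made for illustration; prepare_weights.py may lay the tables out differently.

```python
import numpy as np


def partial_rope_tables(positions, head_dim=512, rotary_dim=128, theta=1_000_000.0):
    # One frequency per rotated pair: i = 0, 2, ..., rotary_dim - 2.
    i = np.arange(0, rotary_dim, 2, dtype=np.float64)
    # Denominator is the full head_dim (512), not rotary_dim (128).
    inv_freq = 1.0 / theta ** (i / head_dim)
    # Non-rotated pairs get zero frequency, so angle = 0, cos = 1, sin = 0.
    pad = np.zeros(head_dim // 2 - rotary_dim // 2)
    inv_freq = np.concatenate([inv_freq, pad])
    angles = np.outer(np.asarray(positions, dtype=np.float64), inv_freq)
    return np.cos(angles), np.sin(angles)   # each [seq, head_dim // 2]


cos, sin = partial_rope_tables(np.arange(4))
print(cos.shape, sin.shape)
```

Using dim=128 instead would make the frequencies fall off four times faster in the exponent, which is enough to break agreement with the HuggingFace reference on global layers.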