<a href="https://colab.research.google.com/github/google/tunix/blob/main/examples/qwen3_example.ipynb" ><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install necessary libraries

In [None]:
!pip install -q kagglehub

!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

!pip install -q datasets

# Download the weights from Kaggle

In [None]:
from flax import nnx
import kagglehub
from tunix.models.qwen3 import model
from tunix.models.qwen3 import params

# Fetch the Qwen3 0.6B checkpoint from Kaggle; the returned value is used
# below as the location of the checkpoint files.
MODEL_CP_PATH = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")

# Pick the config corresponding to the model version downloaded above.
config = model.ModelConfig.qwen3_0_6b()

# Build the model from the downloaded safetensors checkpoint, then display
# the resulting module.
qwen3 = params.create_model_from_safe_tensors(MODEL_CP_PATH, config)
nnx.display(qwen3)

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer from the same location as the downloaded checkpoint
# (assumes the Kaggle bundle ships the tokenizer files alongside the
# weights - TODO confirm).
tokenizer = AutoTokenizer.from_pretrained(MODEL_CP_PATH)

# Run inference

In [None]:
from tunix.generate import sampler


def templatize(prompts, enable_thinking=True):
  """Wraps raw prompts in the tokenizer's chat template.

  Each prompt becomes a single-turn conversation containing one `user`
  message. The conversation is rendered with `tokenize=False` and
  `add_generation_prompt=True`.

  Args:
    prompts: Iterable of user prompt strings.
    enable_thinking: Forwarded to the chat template as `enable_thinking`.
      Defaults to True, which matches the previously hard-coded value.

  Returns:
    A list with one templated prompt per input prompt, in input order.
  """
  # Relies on the notebook-level `tokenizer` created in an earlier cell.
  return [
      tokenizer.apply_chat_template(
          [{"role": "user", "content": prompt}],
          tokenize=False,
          add_generation_prompt=True,
          enable_thinking=enable_thinking,
      )
      for prompt in prompts
  ]


# Render the demo prompts through the chat template.
inputs = templatize([
    "which is larger 9.9 or 9.11?",
    "如何制作月饼?",
    "tell me your name, respond in Chinese",
])

# Give the instance its own name: `sampler` is the imported
# `tunix.generate.sampler` module, and rebinding it to a `Sampler` instance
# would break any later use of `sampler.Sampler` / `sampler.CacheConfig`.
qwen3_sampler = sampler.Sampler(
    qwen3,
    tokenizer,
    sampler.CacheConfig(
        # NOTE(review): the layer/head values look like architecture
        # dimensions of the checkpoint loaded earlier - keep them in sync
        # with `config` if a different model version is selected.
        cache_size=256, num_layers=28, num_kv_heads=8, head_dim=128
    ),
)
# NOTE(review): `echo=True` presumably includes the prompt in the returned
# text - confirm against the Sampler documentation.
out = qwen3_sampler(inputs, max_generation_steps=128, echo=True)

# Print each generated text followed by a separator line.
for text in out.text:
  print(text)
  print("*" * 30)