In [1]:
import torch
from qwen3_model_with_kvcache import Qwen3ModelWithKVCache
from qwen3_tok import Qwen3Tokenizer, load_tokenizer
from qwen3_model_with_kvcache import load_weight
from qwen3_model_with_kvcache import KVCache
import time

In [2]:
from importlib.metadata import version

# Third-party packages this notebook relies on:
#   huggingface_hub -> download the pretrained weights
#   tokenizers      -> implement the tokenizer
#   torch           -> implement the model
pkgs = ["huggingface_hub", "tokenizers", "torch"]

# Print one "<name> version: <installed version>" line per package so the
# environment that produced the outputs below is recorded in the notebook.
for p in pkgs:
    print(p, "version:", version(p))

huggingface_hub version: 1.2.3
tokenizers version: 0.22.1
torch version: 2.9.1


In [3]:
# Size tag of the checkpoint; passed to load_weight() and load_tokenizer() below.
CHOOSE_MODEL = "0.6B"

# Architecture hyperparameters handed to the model constructor.
QWEN3_CONFIG = dict(
    vocab_size=151_936,        # number of entries in the vocabulary
    context_length=40_960,     # context length the model was trained with
    emb_dim=1024,              # embedding dimension
    n_heads=16,                # attention heads
    n_layers=28,               # transformer layers
    hidden_dim=3072,           # intermediate width of the FeedForward block
    head_dim=128,              # per-head size in GQA
    qk_norm=True,              # normalize queries and keys in GQA
    n_kv_groups=8,             # key/value groups for grouped-query attention
    rope_base=1_000_000.0,     # base of RoPE's "theta"
    dtype=torch.bfloat16,      # lower-precision dtype to reduce memory usage
)

In [4]:
# Fix the global PyTorch RNG seed before building the model.
# NOTE(review): presumably this makes any random parameter initialization
# repeatable; load_weight() is called on this model in a later cell — confirm
# whether it overwrites every parameter, in which case the seed is cosmetic.
torch.manual_seed(18)
# Instantiate the architecture described by QWEN3_CONFIG.
model = Qwen3ModelWithKVCache(QWEN3_CONFIG)

In [5]:
# Pick the best available backend: CUDA first, then Apple MPS, else CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Move the model to the chosen device; as the cell's last expression this
# also displays the module structure.
model.to(device)

Qwen3ModelWithKVCache(
  (tok_emb): Embedding(151936, 1024)
  (trf_blocks): ModuleList(
    (0-27): 28 x TransformerBlock(
      (att): GroupedQueryAttention(
        (W_query): Linear(in_features=1024, out_features=2048, bias=False)
        (W_key): Linear(in_features=1024, out_features=1024, bias=False)
        (W_value): Linear(in_features=1024, out_features=1024, bias=False)
        (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
        (q_norm): RMSNorm()
        (k_norm): RMSNorm()
      )
      (ff): FeedForward(
        (fc1): Linear(in_features=1024, out_features=3072, bias=False)
        (fc2): Linear(in_features=1024, out_features=3072, bias=False)
        (fc3): Linear(in_features=3072, out_features=1024, bias=False)
      )
      (norm1): RMSNorm()
      (norm2): RMSNorm()
    )
  )
  (final_norm): RMSNorm()
  (out_head): Linear(in_features=1024, out_features=151936, bias=False)
)

In [6]:
# Populate `model` with weights for the CHOOSE_MODEL variant.
# NOTE(review): load_weight is defined in qwen3_model_with_kvcache; presumably
# it fetches the checkpoint (huggingface_hub is listed above as the download
# dependency) and copies it into the model on `device` — confirm there.
load_weight(model, device, QWEN3_CONFIG, CHOOSE_MODEL)

In [7]:
# Tokenizer for the same CHOOSE_MODEL variant as the weights. Later cells use
# its encode(), decode() and eos_token_id members.
tokenizer = load_tokenizer(CHOOSE_MODEL)

In [8]:
@torch.no_grad()
def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):
    """Greedily decode tokens one at a time, yielding each as it is produced.

    The prompt is run through the model once to fill a fresh KV cache; after
    that only the newest token is fed back on every step.

    Gradient tracking is disabled via the ``@torch.no_grad()`` decorator
    rather than a ``with`` block around the ``yield``: a ``with`` block would
    keep gradients disabled in the *caller's* code for as long as the
    generator is suspended, whereas the decorator only disables them while
    the generator body itself is running.

    Args:
        model: Module exposing ``cfg["n_layers"]``, ``reset_kv_cache()``,
            ``eval()`` and ``model(ids, cache=...)`` returning logits that are
            indexed as ``logits[:, -1]``.
        token_ids: Prompt token ids. Assumed to be shaped (batch, seq_len);
            the caller in this notebook passes (1, seq_len).
        max_new_tokens: Upper bound on the number of tokens yielded.
        eos_token_id: If given, generation stops (without yielding it) as soon
            as every sequence in the batch picks this token.
        context_size: Unused; accepted only for backward compatibility.

    Yields:
        The argmax token id(s) for the current step, with the reduced
        dimension kept (``keepdim=True``).
    """
    model.eval()

    cache = KVCache(n_layers=model.cfg["n_layers"])
    model.reset_kv_cache()

    # Prime the cache with the initial context
    logits = model(token_ids, cache=cache)

    for step in range(max_new_tokens):
        next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)

        if eos_token_id is not None and torch.all(next_token == eos_token_id):
            break

        yield next_token

        # Feed only the new token to the model; cache handles history.
        # Skip the forward pass after the final permitted token: its logits
        # would never be read.
        if step + 1 < max_new_tokens:
            logits = model(next_token, cache=cache)

In [10]:
# Prompt is kept verbatim: the recorded output below was produced from it.
p = "Can you explain what MultiHead latent attention ?"
input_token_ids = tokenizer.encode(p)
# Add a leading batch dimension of size 1 for the model.
input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)

gen_tokens = 0
# perf_counter() is monotonic and high-resolution, unlike time.time(), so it
# is the right clock for measuring an elapsed interval.
t0 = time.perf_counter()

for token in generate_text_basic_stream(
    model=model,
    token_ids=input_token_ids_tensor,
    max_new_tokens=8192,
    eos_token_id=tokenizer.eos_token_id
):
    token_id = token.squeeze(0).tolist()
    gen_tokens += 1
    # Decode once and stream the text immediately.
    text = tokenizer.decode(token_id)
    print(text, end="", flush=True)

# The measured interval starts before the generator's first step, so it
# covers the prompt pass as well as per-token decoding and printing.
decode_time = time.perf_counter() - t0
decode_toks_per_sec = gen_tokens / decode_time

print("\n\ntoken/s", decode_toks_per_sec)

<think>
Okay, so I need to explain what MultiHead Latent Attention is. Hmm, I remember that in neural networks, especially in transformer models, attention mechanisms are used to find the most relevant parts of the input. But how does MultiHead work?

Wait, MultiHead is a technique where the attention mechanism has multiple heads. Each head is a separate attention mechanism. So maybe each head is responsible for different parts of the input. For example, one head might focus on the first part, another on the second, and so on. That way, the model can capture different aspects of the input.

But then there's the latent attention part. Oh right, latent attention is a variation of the attention mechanism where the input is represented as a latent space. So instead of looking at the actual input, the model is looking at the latent representation. This might help in capturing more abstract or higher-level features.

So putting it all together, MultiHead Latent Attention is a method where th