# Demo: Observing Attention Configuration (MHA, MQA, GQA)

**Goal:** This demo shows how to inspect a model's configuration to determine its attention mechanism (Multi-Head, Multi-Query, or Grouped-Query Attention). Understanding this is key to predicting a model's memory usage for the KV Cache.

We will:
1. **Setup:** Install libraries and log in to Hugging Face.
2. **Load Configuration:** Efficiently load only the model's configuration file.
3. **Inspect Heads:** Extract the number of query heads and key/value heads.
4. **Determine Type:** Apply logic to identify the attention type.
5. **Compute Grouping Factor:** For GQA, calculate how many query heads share each key/value head.
6. **Interpret Results:** Understand the implications of the finding.

## Step 1: Setup Environment

First, we install the necessary libraries.

In [1]:
!pip install transformers torch accelerate



Next, we log in to Hugging Face to get access to models like Llama.

**Action Required:**
1. Go to `huggingface.co/settings/tokens`.
2. Create a new token with "read" permissions.
3. Paste your token into the cell below.

In [None]:
from huggingface_hub import login

# Paste your Hugging Face token here
HF_TOKEN = "<YOUR_HUGGING_FACE_TOKEN>"

login(token=HF_TOKEN)

Finally, import the libraries we'll use.

In [3]:
from transformers import AutoConfig

## Step 2: Load Model Configuration

For this task, we don't need to load the entire model, which can have billions of parameters and occupy many gigabytes of VRAM. We only need its architectural details, which are stored in the `config.json` file. The `AutoConfig` class lets us load just that file, without downloading any weights.

In [4]:
model_name = "meta-llama/Llama-3.2-1B"

print(f"Loading configuration for: {model_name}")
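try:
    config = AutoConfig.from_pretrained(model_name)
    print("Configuration loaded successfully!")
except Exception as e:
    # Gated models such as Llama require an accepted license and a valid token
    print(f"Error loading model configuration: {e}")
    raise  # Stop here: the following cells depend on `config`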

Loading configuration for: meta-llama/Llama-3.2-1B
Configuration loaded successfully!
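
To make concrete what kind of data `AutoConfig` reads, here is a minimal sketch that parses a `config.json`-style document with the standard library only. The JSON text is a hypothetical two-field excerpt (a real file contains many more fields), and the fallback for a missing key/value field is an assumption, not documented library behavior.

```python
import json

# Hypothetical excerpt of a config.json; a real file contains many more fields.
config_text = """
{
  "num_attention_heads": 32,
  "num_key_value_heads": 8
}
"""

raw_config = json.loads(config_text)

# The config is a plain mapping of architectural hyperparameters, not weights.
n_q = raw_config["num_attention_heads"]
# Assumption: if the KV field is absent, treat the model as standard MHA.
n_kv = raw_config.get("num_key_value_heads", n_q)

print(f"query heads: {n_q}, key/value heads: {n_kv}")
```

Because the file is only a few kilobytes of JSON, inspecting it is cheap compared with loading the model itself.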


## Step 3: Inspect Key Attention Attributes

The two most important attributes for determining the attention type are:
- `num_attention_heads`: The number of attention heads for the **Query (Q)** projections.
- `num_key_value_heads`: The number of attention heads for the **Key (K) and Value (V)** projections.

In [5]:
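# Extract the number of heads from the configuration object
num_q_heads = config.num_attention_heads
# Some configs omit this field or set it to None; we assume that means MHA
# and fall back to the query head count.
num_kv_heads = getattr(config, "num_key_value_heads", None) or num_q_heads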

print(f"Extracted Attributes for '{model_name}':")
print(f"  Number of Query Heads (N_q):         {num_q_heads}")
print(f"  Number of Key/Value Heads (N_kv):    {num_kv_heads}")

Extracted Attributes for 'meta-llama/Llama-3.2-1B':
  Number of Query Heads (N_q):         32
  Number of Key/Value Heads (N_kv):    8


## Step 4: Determine the Attention Type

Now we can apply simple logic based on the two numbers we just extracted:
- If `N_q == N_kv`, it's **Multi-Head Attention (MHA)**.
- If `N_kv == 1`, it's **Multi-Query Attention (MQA)**.
- If `1 < N_kv < N_q`, it's **Grouped-Query Attention (GQA)**.

In [6]:
attention_type = "Unknown"

if num_kv_heads == num_q_heads:
    attention_type = "Multi-Head Attention (MHA)"
elif num_kv_heads == 1:
    attention_type = "Multi-Query Attention (MQA)"
elif 1 < num_kv_heads < num_q_heads:
    attention_type = "Grouped-Query Attention (GQA)"

print(f"Based on the head counts, the detected attention type is: {attention_type}")

Based on the head counts, the detected attention type is: Grouped-Query Attention (GQA)
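
The same rules can be wrapped in a small reusable function, which makes it easy to check several head-count pairs at once. This is a sketch: `classify_attention` is a hypothetical helper name, and every pair below except (32, 8) is made up for illustration.

```python
def classify_attention(num_q_heads: int, num_kv_heads: int) -> str:
    """Classify an attention layout from its query and key/value head counts."""
    if num_q_heads <= 0 or num_kv_heads <= 0:
        raise ValueError("head counts must be positive")
    if num_kv_heads == num_q_heads:
        return "MHA"
    if num_kv_heads == 1:
        return "MQA"
    if num_kv_heads < num_q_heads:
        return "GQA"
    # More KV heads than query heads is not a layout covered by these rules.
    return "Unknown"


# Illustrative pairs; only (32, 8) comes from the configuration loaded above.
for n_q, n_kv in [(32, 32), (32, 1), (32, 8), (8, 32)]:
    print(f"N_q={n_q:>2}, N_kv={n_kv:>2} -> {classify_attention(n_q, n_kv)}")
```

Checking the MHA case first means a single-head model (`N_q == N_kv == 1`) is reported as MHA rather than MQA, matching the order of the rules above.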


## Step 5: Calculate Grouping Factor for GQA

Since we detected GQA, we can calculate the *grouping factor*, that is, how many Query heads share a single Key/Value head.

In [7]:
if attention_type == "Grouped-Query Attention (GQA)":
    if num_q_heads % num_kv_heads == 0:
        group_factor = num_q_heads // num_kv_heads
        print(f"Grouping Factor = {num_q_heads} (Query Heads) / {num_kv_heads} (KV Heads) = {group_factor}")
        print(f"This means every {group_factor} Query heads share a single set of Key and Value heads.")
    else:
        print("GQA detected, but heads are not evenly divisible.")

Grouping Factor = 32 (Query Heads) / 8 (KV Heads) = 4
This means every 4 Query heads share a single set of Key and Value heads.
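
To visualize the grouping, the sketch below lists which Query heads would share each Key/Value head. It assumes that consecutive Query heads form a group, so Query head `i` reads from KV head `i // group_factor`. That layout is an assumption for illustration, and `kv_head_for` is a hypothetical helper, not part of any library.

```python
n_q = 32   # query heads, from the output above
n_kv = 8   # key/value heads, from the output above
group_factor = n_q // n_kv


def kv_head_for(q_head: int) -> int:
    """Map a query head index to its KV head, assuming consecutive grouping."""
    return q_head // group_factor


# Collect the query heads that belong to each KV head.
groups = {}
for q in range(n_q):
    groups.setdefault(kv_head_for(q), []).append(q)

for kv, qs in groups.items():
    print(f"KV head {kv}: query heads {qs}")
```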


## Final Interpretation

For the model **`meta-llama/Llama-3.2-1B`**, we have confirmed it uses **Grouped-Query Attention (GQA)**.

**Why this matters:**
- **Reduced Memory:** An otherwise identical MHA model would cache 32 Key/Value heads per layer. By caching only 8, GQA reduces the KV Cache size by a factor of 4 (32 / 8).
- **Faster Inference:** A smaller KV Cache means less data is read from relatively slow GPU memory (HBM) at each generation step, which eases the memory bandwidth bottleneck and speeds up inference.
- **Longer Context:** The memory saved by GQA can hold the cache for longer sequences of text before the model runs out of VRAM.
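
The memory claims above can be checked with a back-of-the-envelope estimate. The sketch assumes the cache stores one Key and one Value tensor per layer, so its size is roughly `2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_element`. The layer count, head dimension, sequence length, and 2-byte element size below are illustrative assumptions, not values read from the configuration in this demo, and `kv_cache_bytes` and `max_tokens` are hypothetical helpers.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_element=2):
    """Rough KV cache size: one K and one V tensor per layer."""
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_element)


def max_tokens(budget_bytes, num_layers, num_kv_heads, head_dim,
               bytes_per_element=2):
    """How many cached tokens fit in a memory budget (batch size 1)."""
    per_token = kv_cache_bytes(num_layers, num_kv_heads, head_dim,
                               seq_len=1, bytes_per_element=bytes_per_element)
    return budget_bytes // per_token


# Illustrative assumptions, not values read from the config in this demo.
NUM_LAYERS = 16
HEAD_DIM = 64
SEQ_LEN = 8192
BUDGET = 2**30  # 1 GiB reserved for the cache

for label, n_kv in [("MHA", 32), ("GQA", 8), ("MQA", 1)]:
    size = kv_cache_bytes(NUM_LAYERS, n_kv, HEAD_DIM, SEQ_LEN)
    fit = max_tokens(BUDGET, NUM_LAYERS, n_kv, HEAD_DIM)
    print(f"{label} (N_kv={n_kv:>2}): {size / 2**20:,.0f} MiB "
          f"for {SEQ_LEN} tokens; {fit:,} tokens fit in 1 GiB")
```

Under these assumptions the cache size scales linearly with the number of Key/Value heads, so the 4x reduction from 32 to 8 heads translates directly into 4x more tokens for the same budget.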