# Multi-Head Attention

Multi-Head Attention is an important part of all Transformer-based models.
This tutorial will show how to write it and how to then optimize it.

In [1]:
import os


In [2]:
import time
import torch
import ttnn

torch.manual_seed(0)

device_id = 0
dispatch_core_type = ttnn.device.DispatchCoreType.ETH
if os.environ.get("ARCH_NAME") and "grayskull" in os.environ.get("ARCH_NAME"):
    dispatch_core_type = ttnn.device.DispatchCoreType.WORKER
device = ttnn.open_device(device_id=device_id, l1_small_size=8192, dispatch_core_config=ttnn.device.DispatchCoreConfig(dispatch_core_type))

2025-07-02 10:31:43.775 | DEBUG    | ttnn:<module>:83 - Initial ttnn.CONFIG:
Config{cache_path=/home/ubuntu/.cache/ttnn,model_cache_path=/home/ubuntu/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_should_raise_exception=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


2025-07-02 10:31:43.820 | info     |   SiliconDriver | Opened PCI device 0; KMD version: 2.0.0; API: 2; IOMMU: disabled (pci_device.cpp:198)
2025-07-02 10:31:43.820 | info     |   SiliconDriver | Opened PCI device 0; KMD version: 2.0.0; API: 2; IOMMU: disabled (pci_device.cpp:198)
2025-07-02 10:31:43.833 | info     |          Device | Opening user mode device driver (tt_cluster.cpp:174)
2025-07-02 10:31:43.833 | info     |   SiliconDriver | Opened PCI device 0; KMD version: 2.0.0; API: 2; IOMMU: disabled (pci_device.cpp:198)
2025-07-02 10:31:43.833 | info     |   SiliconDriver | Opened PCI device 0; KMD version: 2.0.0; API: 2; IOMMU: disabled (pci_device.cpp:198)
2025-07-02 10:31:43.838 | info     |   SiliconDriver | Opened PCI device 0; KMD version: 2.0.0; API: 2; IOMMU: disabled (pci_device.cpp:198)
2025-07-02 10:31:43.838 | info     |   SiliconDriver | Opened PCI device 0; KMD version: 2.0.0; API: 2; IOMMU: disabled (pci_device.cpp:198)
2025-07-02 10:31:43.843 | info     |   Silicon

## Enable program cache

In [3]:
device.enable_program_cache()

2025-07-02 10:31:44.774 | info     |           Metal | Enabling program cache on MeshDevice 1 (mesh_device.cpp:537)


## Write Multi-Head Attention using ttnn

Multi-head can be implemented in `torch` using just 6 operations:

1. `torch.matmul`
2. `torch.add`
3. `torch.reshape`
4. `torch.permute`
5. `torch.mul`
6. `torch.softmax`

`ttnn` provides the exact same APIs to do that and therefore multi-head attention can be implemented in a very similar fashion. Except, when using `ttnn`, the user should be mindful of the tensor layout.

In [4]:
def multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    *,
    num_heads,
):
    fallback_reshape = ttnn.get_fallback_function(ttnn.reshape) 
       
    batch_size, sequence_size, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    query = hidden_states @ query_weight
    query = query + query_bias
    query = ttnn.to_layout(query, layout=ttnn.ROW_MAJOR_LAYOUT)
    query = fallback_reshape(query, (batch_size, sequence_size, num_heads, head_size))
    query = ttnn.to_layout(query, layout=ttnn.TILE_LAYOUT)
    query = ttnn.permute(query, (0, 2, 1, 3))

    key = hidden_states @ key_weight
    key = key + key_bias
    key = ttnn.to_layout(key, layout=ttnn.ROW_MAJOR_LAYOUT)
    key = fallback_reshape(key, (batch_size, sequence_size, num_heads, head_size))
    key = ttnn.to_layout(key, layout=ttnn.TILE_LAYOUT)
    key = ttnn.permute(key, (0, 2, 3, 1))

    value = hidden_states @ value_weight
    value = value + value_bias
    value = ttnn.to_layout(value, layout=ttnn.ROW_MAJOR_LAYOUT)
    value = fallback_reshape(value, (batch_size, sequence_size, num_heads, head_size))
    value = ttnn.to_layout(value, layout=ttnn.TILE_LAYOUT)
    value = ttnn.permute(value, (0, 2, 1, 3))

    attention_scores = query @ key
    attention_scores = attention_scores * (1 / (head_size**0.5))
    attention_scores += attention_mask
    attention_probs = ttnn.softmax(attention_scores, dim=-1)

    context_layer = attention_probs @ value
    context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
    context_layer = ttnn.to_layout(context_layer, layout=ttnn.ROW_MAJOR_LAYOUT)
    context_layer = fallback_reshape(context_layer, (batch_size, sequence_size, hidden_size))
    context_layer = ttnn.to_layout(context_layer, layout=ttnn.TILE_LAYOUT)

    self_output = context_layer @ output_weight
    self_output = self_output + output_bias

    return self_output

Now that the model is written, let's create input tensors to run it and test it

## Configuration

In [5]:
batch_size = 8
sequence_size = 384
num_heads = 16
head_size = 64
hidden_size = num_heads * head_size

## Initialize activations and weights

In [6]:
hidden_states = ttnn.rand((batch_size, sequence_size, hidden_size), dtype=ttnn.bfloat16, device=device)
attention_mask = ttnn.rand((batch_size, 1, 1, sequence_size), dtype=ttnn.bfloat16, device=device)
query_weight = ttnn.rand((hidden_size, hidden_size), dtype=ttnn.bfloat16, device=device)
query_bias = ttnn.rand((hidden_size,), dtype=ttnn.bfloat16, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
key_weight = ttnn.rand((hidden_size, hidden_size), dtype=ttnn.bfloat16, device=device)
key_bias = ttnn.rand((hidden_size,), dtype=ttnn.bfloat16, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
value_weight = ttnn.rand((hidden_size, hidden_size), dtype=ttnn.bfloat16, device=device)
value_bias = ttnn.rand((hidden_size,), dtype=ttnn.bfloat16, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.rand((hidden_size, hidden_size), dtype=ttnn.bfloat16, device=device)
output_bias = ttnn.rand((hidden_size,), dtype=ttnn.bfloat16, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)

## Run the first iteration of Multi-Head Attention

In [7]:
start = time.time()
multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start

In [8]:
print(f"Multi-head attention ran in {duration} seconds for the first iteration")

Multi-head attention ran in 4.554651975631714 seconds for the first iteration


## Run a subsequent iteration of Multi-Head Attention

In [9]:
start = time.time()
output = multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start

In [10]:
print(f"Multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")

Multi-head attention ran in 0.05135488510131836 seconds for the subsequent iteration because of the program cache


## Write optimized version of Multi-Head Attention

Optimized version of the multi-head attention can be written by:

- Tilizing all of the tensors ahead of time
- Using more performant matmuls that fuse bias and specify the number of cores they execute on
- Putting every tensor into L1
- Using bfloat8_b data_type
- Using custom `ttnn.transformer` operations instead of `ttnn.permute` and `ttnn.reshape`

`ttnn.deallocate` calls are needed because otherwise, the cores on the device will run out of the L1 memory

In [11]:
def optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    fused_qkv_weight,
    fused_qkv_bias,
    self_output_weight,
    self_output_bias,
    *,
    num_heads,
    num_cores_x=12,
):
    batch_size, _, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads
    
    hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)

    fused_qkv_output = ttnn.linear(
        hidden_states,
        fused_qkv_weight,
        bias=fused_qkv_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )

    (
        query,
        key,
        value,
    ) = ttnn.transformer.split_query_key_value_and_split_heads(
        fused_qkv_output,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        num_heads=num_heads,
    )
    ttnn.deallocate(fused_qkv_output)

    attention_scores = ttnn.matmul(
        query,
        key,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(query)
    ttnn.deallocate(key)

    attention_probs = ttnn.transformer.attention_softmax_(attention_scores, attention_mask=attention_mask, head_size=head_size)

    context_layer = ttnn.matmul(
        attention_probs,
        value,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(attention_probs)

    context_layer_after_concatenate_heads = ttnn.transformer.concatenate_heads(
        context_layer,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    ttnn.deallocate(context_layer)

    self_output = ttnn.linear(
        context_layer_after_concatenate_heads,
        self_output_weight,
        bias=self_output_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(context_layer_after_concatenate_heads)

    return self_output

## Pre-process the parameters of the optimized model

1. Fuse QKV weights and biases
2. Reshape and tilize for the optimized operations using preprocess_linear_weight and preprocess_linear_bias
3. Move to device

In [12]:
qkv_weight = ttnn.concat([query_weight, key_weight, value_weight], dim=-1)
qkv_weight = ttnn.to_device(ttnn.to_layout(qkv_weight, layout=ttnn.TILE_LAYOUT), device=device)

qkv_bias = ttnn.from_torch(torch.cat([ttnn.to_torch(query_bias), ttnn.to_torch(key_bias), ttnn.to_torch(value_bias)], dim=-1), device=device)
qkv_bias = ttnn.reshape(qkv_bias, (1, -1))
qkv_bias = ttnn.to_layout(qkv_bias, layout=ttnn.TILE_LAYOUT)

output_weight = ttnn.to_layout(output_weight, layout=ttnn.TILE_LAYOUT)

output_bias = ttnn.reshape(output_bias, (1, -1))
output_bias = ttnn.to_layout(output_bias, layout=ttnn.TILE_LAYOUT)

## Run the first iteration of the optimized Multi-Head Attention

In [13]:
start = time.time()
hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)
optimized_output = optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start

In [14]:
print(f"Optimized multi-head attention ran in {duration} seconds for the first iteration")

Optimized multi-head attention ran in 2.4207022190093994 seconds for the first iteration


## Run a subsequent iteration of the optimized Multi-Head Attention

In [15]:
start = time.time()
optimized_output = optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start

In [16]:
print(f"Optimized multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")

Optimized multi-head attention ran in 0.0037069320678710938 seconds for the subsequent iteration because of the program cache


Note that the optimized multi-head attention is 2 orders of magnitude faster than the initial version

## Check that the output of the optimized version matches the output of the original implementation

In [17]:
torch_output = ttnn.to_torch(output)
torch_optimized_output = ttnn.to_torch(optimized_output)

assert torch.allclose(torch_output, torch_optimized_output)

## Close the device

In [18]:
ttnn.close_device(device)

2025-07-02 10:31:53.340 | info     |           Metal | Closing mesh device 1 (mesh_device.cpp:488)
2025-07-02 10:31:53.341 | info     |           Metal | Closing mesh device 0 (mesh_device.cpp:488)
2025-07-02 10:31:53.342 | info     |           Metal | Closing device 0 (device.cpp:469)
2025-07-02 10:31:53.342 | info     |           Metal | Disabling and clearing program cache on device 0 (device.cpp:781)
