# Multi-Head Attention

Multi-Head Attention is an important part of all Transformer-based models.
This tutorial will show how to write it and how to then optimize it.

In [1]:
import time
import torch
import ttnn

# Fix the torch RNG seed so the randomly generated activations and weights
# created later in the notebook are reproducible across runs.
torch.manual_seed(0)

# Open device 0. The returned handle is passed to ttnn.to_device when tensors
# are moved onto the device and to ttnn.close at the end of the notebook.
device_id = 0
device = ttnn.open(device_id)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0
[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening user mode device driver
[32m2023-12-05 03:37:48.916[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected 1 PCI device
[32m2023-12-05 03:37:48.927[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (pci_interface_id: 0 device_id: 0xfaca revision: 0)
[0;33m---- ttSiliconDevice::init_hugepage: bind_area_to_memory_nodeset() failed (physical_device_id: 0 ch: 0). Hugepage allocation is not on NumaNode matching TT Device. Side-Effect is decreased Device->Host perf (Issue #893).
[0m[32m2023-12-05 03:37:49.023[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Disable PCIE DMA
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | AI CLK for d

## Enable program cache

In [2]:
# Enable the program cache before any op runs. The "subsequent iteration"
# timing cells below attribute their speed-up to this cache.
ttnn.enable_program_cache()

[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program Cache: enabled.


## Write Multi-Head Attention using ttnn

Multi-head attention can be implemented in `torch` using just 6 operations:
1. `torch.matmul`
2. `torch.add`
3. `torch.reshape`
4. `torch.permute`
5. `torch.mul`
6. `torch.softmax`

`ttnn` provides the exact same APIs to do that and therefore multi-head attention can be implemented as shown below:

In [3]:
def multi_head_attention(
    hidden_states,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    *,
    num_heads,
):
    """Compute multi-head self-attention followed by the output projection.

    Args:
        hidden_states: Tensor of shape (batch_size, sequence_size, hidden_size).
        query_weight, key_weight, value_weight: Projection weights applied as
            ``hidden_states @ weight``; each must map hidden_size to hidden_size.
        query_bias, key_bias, value_bias: Biases added to the projections.
        output_weight: Weight of the final projection, applied as
            ``context @ output_weight``.
        output_bias: Bias added after the final projection.
        num_heads: Number of attention heads; must divide hidden_size evenly.

    Returns:
        The result of projecting the (batch_size, sequence_size, hidden_size)
        attention context with ``output_weight`` and ``output_bias``.

    Raises:
        ValueError: If hidden_size is not divisible by num_heads.
    """
    batch_size, sequence_size, hidden_size = hidden_states.shape
    # Fail early with a clear message instead of letting the floor division
    # below produce a head_size that no longer matches the reshape targets.
    if hidden_size % num_heads != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        )
    head_size = hidden_size // num_heads

    # Query: project, split the hidden dimension into heads, then move the
    # head axis ahead of the sequence axis -> (batch, heads, seq, head_size).
    query = hidden_states @ query_weight
    query = query + query_bias
    query = ttnn.reshape(query, (batch_size, sequence_size, num_heads, head_size))
    query = ttnn.permute(query, (0, 2, 1, 3))

    # Key: same projection and split, but permuted to
    # (batch, heads, head_size, seq) so that `query @ key` yields the scores.
    key = hidden_states @ key_weight
    key = key + key_bias
    key = ttnn.reshape(key, (batch_size, sequence_size, num_heads, head_size))
    key = ttnn.permute(key, (0, 2, 3, 1))

    # Value: same layout as query -> (batch, heads, seq, head_size).
    value = hidden_states @ value_weight
    value = value + value_bias
    value = ttnn.reshape(value, (batch_size, sequence_size, num_heads, head_size))
    value = ttnn.permute(value, (0, 2, 1, 3))

    # Scaled dot-product attention: scores are scaled by 1/sqrt(head_size)
    # and normalized over the last axis.
    attention_scores = query @ key
    attention_scores = attention_scores * (1 / (head_size**0.5))
    attention_probs = ttnn.softmax(attention_scores, dim=-1)

    # Weighted sum of values, then merge the heads back into hidden_size.
    context_layer = attention_probs @ value
    context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
    context_layer = ttnn.reshape(context_layer, (batch_size, sequence_size, hidden_size))

    self_output = context_layer @ output_weight
    self_output = self_output + output_bias

    return self_output

Now that the model is written, let's create input tensors to run it and test it.

## Configuration

In [4]:
# Problem size used for every run in this notebook. hidden_size is derived
# from num_heads * head_size, so it is always an exact multiple of num_heads.
batch_size = 8
sequence_size = 384
num_heads = 16
head_size = 64
hidden_size = num_heads * head_size

## Initialize activations and weights using torch

In [5]:
# Random bfloat16 activations and parameters. The values come from the torch
# RNG seeded at the top of the notebook, so the order of these calls
# determines which values each tensor receives.
torch_hidden_states = torch.randn((batch_size, sequence_size, hidden_size), dtype=torch.bfloat16)
# Each projection has a square (hidden_size, hidden_size) weight and a
# (hidden_size,) bias.
torch_query_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_query_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_key_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_key_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_value_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_value_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_output_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_output_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)

## Convert activations and weights to ttnn

In [6]:
# Wrap every torch tensor as a ttnn tensor, in the same order as before.
# Moving the tensors onto the device happens in a separate step.
(
    hidden_states,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
) = (
    ttnn.from_torch(torch_tensor)
    for torch_tensor in (
        torch_hidden_states,
        torch_query_weight,
        torch_query_bias,
        torch_key_weight,
        torch_key_bias,
        torch_value_weight,
        torch_value_bias,
        torch_output_weight,
        torch_output_bias,
    )
)

## Move activations and weights to device

In [7]:
# Move everything onto the opened device. Activations and weights use
# to_device's default memory config; the biases are explicitly placed in L1.
hidden_states = ttnn.to_device(hidden_states, device)
query_weight = ttnn.to_device(query_weight, device)
query_bias = ttnn.to_device(query_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
key_weight = ttnn.to_device(key_weight, device)
key_bias = ttnn.to_device(key_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
value_weight = ttnn.to_device(value_weight, device)
value_bias = ttnn.to_device(value_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.to_device(output_weight, device)
output_bias = ttnn.to_device(output_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)

## Run the first iteration of Multi-Head Attention

In [8]:
# Time the first call. time.perf_counter() is used because it is a monotonic,
# high-resolution clock intended for measuring intervals, whereas time.time()
# can jump if the system clock is adjusted.
start = time.perf_counter()
multi_head_attention(
    hidden_states,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.perf_counter()
duration = end - start

[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::Tilize                               finished in     0.442773428 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::Tilize                               finished in     0.283801031 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::Matmul                               finished in     0.491539668 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::TilizeWithValPadding                 finished in     0.457688582 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::EltwiseBinaryBroadcast               finished in     0.448720675 seconds
[38;2;000;128;000m      

In [9]:
# Report the wall-clock duration measured by the previous cell.
print(f"Multi-head attention ran in {duration} seconds for the first iteration")

Multi-head attention ran in 7.927839279174805 seconds for the first iteration


## Run a subsequent iteration of Multi-Head Attention

In [10]:
# Time a repeat call with identical inputs and keep its result in `output`
# for the comparison at the end of the notebook. time.perf_counter() is a
# monotonic clock intended for interval measurement.
start = time.perf_counter()
output = multi_head_attention(
    hidden_states,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.perf_counter()
duration = end - start

[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::Tilize                               finished in     0.005479948 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::Tilize                               finished in     0.001622081 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::Matmul                               finished in     0.001215313 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::TilizeWithValPadding                 finished in     0.000957504 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::EltwiseBinaryBroadcast               finished in     0.001032264 seconds
[38;2;000;128;000m      

In [11]:
# Report the duration of the repeat call measured by the previous cell.
print(f"Multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")

Multi-head attention ran in 0.6879308223724365 seconds for the subsequent iteration because of the program cache


## Write optimized version of Multi-Head Attention

Optimized version of the multi-head attention can be written by:
- Tilizing all of the tensors ahead of time
- Using more performant matmuls that fuse bias and specify the number of cores they execute on
- Putting every tensor into L1
- Using the `bfloat8_b` data type
- Using custom `nlp` operations instead of `ttnn.permute` and `ttnn.reshape`

`ttnn.deallocate` calls are needed because otherwise the cores on the device will run out of L1 memory.

In [12]:
def optimized_multi_head_attention(
    hidden_states,
    fused_qkv_weight,
    fused_qkv_bias,
    self_output_weight,
    self_output_bias,
    *,
    num_heads,
    num_cores_x=12,
):
    """Multi-head self-attention using fused ttnn operations with outputs in L1.

    Args:
        hidden_states: Tensor whose shape unpacks as
            (batch_size, sequence_size, hidden_size). It is converted to tile
            layout inside the function.
        fused_qkv_weight: Weight of the single linear op that produces the
            fused query/key/value tensor.
        fused_qkv_bias: Bias fused into that linear op.
        self_output_weight: Weight of the final output linear op.
        self_output_bias: Bias fused into the final output linear op.
        num_heads: Number of attention heads; must divide hidden_size evenly.
        num_cores_x: Second component of the ``core_grid`` passed to every
            matmul/linear op; the first component is batch_size.

    Returns:
        The bfloat16 output of the final linear op, created with
        ``ttnn.L1_MEMORY_CONFIG``.

    Raises:
        ValueError: If hidden_size is not divisible by num_heads.
    """
    batch_size, _, hidden_size = hidden_states.shape
    # head_size only feeds the softmax scaling below, so a non-divisible
    # hidden_size would otherwise go unnoticed here and yield a wrong scale.
    if hidden_size % num_heads != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        )
    head_size = hidden_size // num_heads

    hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)

    # One linear op produces query, key and value together, with the bias fused.
    fused_qkv_output = ttnn.linear(
        hidden_states,
        fused_qkv_weight,
        bias=fused_qkv_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=(batch_size, num_cores_x),
    )

    (
        query,
        key,
        value,
    ) = ttnn.transformer.split_query_key_value_and_split_heads(
        fused_qkv_output,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        num_heads=num_heads,
    )
    # Each intermediate is released as soon as its last consumer has run, to
    # keep L1 usage down.
    ttnn.deallocate(fused_qkv_output)

    attention_scores = ttnn.matmul(
        query,
        key,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=(batch_size, num_cores_x),
    )
    ttnn.deallocate(query)
    ttnn.deallocate(key)

    # NOTE(review): attention_scores is not deallocated here because it is
    # unclear from this file whether attention_softmax reuses its buffer for
    # the result - confirm before freeing it.
    attention_probs = ttnn.transformer.attention_softmax(attention_scores, attention_mask=None, head_size=head_size)

    context_layer = ttnn.matmul(
        attention_probs,
        value,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=(batch_size, num_cores_x),
    )
    ttnn.deallocate(attention_probs)
    # value has no further consumers after this matmul.
    ttnn.deallocate(value)

    # Use a separate name for the merged tensor so the per-head context
    # tensor can be released explicitly instead of just being rebound.
    merged_context_layer = ttnn.transformer.concatenate_heads(
        context_layer,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    ttnn.deallocate(context_layer)

    self_output = ttnn.linear(
        merged_context_layer,
        self_output_weight,
        bias=self_output_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=(batch_size, num_cores_x),
    )
    ttnn.deallocate(merged_context_layer)

    return self_output

## Pre-process the parameters of the optimized model

1. Fuse QKV weights and biases
2. Reshape and tilize for the optimized operations using preprocess_linear_weight and preprocess_linear_bias
3. Move to device

In [13]:
from ttnn.model_preprocessing import (
    preprocess_linear_bias,
    preprocess_linear_weight,
)

# Fuse the three projections: concatenating along the last dimension yields a
# (hidden_size, 3 * hidden_size) weight and a (3 * hidden_size,) bias.
torch_qkv_weight = torch.cat([torch_query_weight, torch_key_weight, torch_value_weight], dim=-1)
torch_qkv_bias = torch.cat([torch_query_bias, torch_key_bias, torch_value_bias], dim=-1)

# NOTE(review): the weights are transposed before preprocessing, presumably
# because preprocess_linear_weight expects the torch.nn.Linear
# (out_features, in_features) layout - confirm against its documentation.
qkv_weight = preprocess_linear_weight(torch_qkv_weight.T, dtype=ttnn.bfloat16)
qkv_bias = preprocess_linear_bias(torch_qkv_bias, dtype=ttnn.bfloat16)
# NOTE(review): output_weight / output_bias are rebound here, replacing the
# tensors that the unoptimized multi_head_attention cells were called with.
# Re-running those earlier cells after this one would use the preprocessed
# versions instead.
output_weight = preprocess_linear_weight(torch_output_weight.T, dtype=ttnn.bfloat16)
output_bias = preprocess_linear_bias(torch_output_bias, dtype=ttnn.bfloat16)

# Move the preprocessed parameters to the device; biases go to L1 explicitly.
qkv_weight = ttnn.to_device(qkv_weight, device)
qkv_bias = ttnn.to_device(qkv_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.to_device(output_weight, device)
output_bias = ttnn.to_device(output_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)

## Run the first iteration of the optimized Multi-Head Attention

In [14]:
# Tilize the activations ahead of time, outside the timed region, so the
# measurement covers only the attention call itself.
hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)

# time.perf_counter() is a monotonic clock intended for interval measurement.
start = time.perf_counter()
optimized_output = optimized_multi_head_attention(
    hidden_states,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.perf_counter()
duration = end - start

[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::Tilize                               finished in     0.005119581 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::operations::primary::Matmul                    finished in     0.650794504 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::operations::primary::transformers::SplitFusedQKVAndSplitHeads finished in     0.478532563 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::operations::primary::Matmul                    finished in     0.612131647 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::EltwiseBinaryBroadcast               finished in     0.385793305 seconds
[38;2;000

In [15]:
# Report the duration of the first optimized call measured by the previous cell.
print(f"Optimized multi-head attention ran in {duration} seconds for the first iteration")

Optimized multi-head attention ran in 3.7684011459350586 seconds for the first iteration


## Run a subsequent iteration of the optimized Multi-Head Attention

In [16]:
# Time a repeat call of the optimized implementation with identical inputs.
# Its result is the one compared against the baseline output later on.
# time.perf_counter() is a monotonic clock intended for interval measurement.
start = time.perf_counter()
optimized_output = optimized_multi_head_attention(
    hidden_states,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.perf_counter()
duration = end - start

[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::operations::primary::Matmul                    finished in     0.000593846 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::operations::primary::transformers::SplitFusedQKVAndSplitHeads finished in     0.000213378 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::operations::primary::Matmul                    finished in     0.000458387 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::EltwiseBinaryBroadcast               finished in     0.001472872 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::operations::primary::Softmax                   finished in     0.001904859 seconds
[38;2;000

In [17]:
# Report the duration of the repeat optimized call measured by the previous cell.
print(f"Optimized multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")

Optimized multi-head attention ran in 0.008018732070922852 seconds for the subsequent iteration because of the program cache


Note that, on the subsequent (cached) iteration, the optimized multi-head attention is roughly two orders of magnitude faster than the initial version: about 0.008 seconds versus about 0.69 seconds.

## Check that the output of the optimized version matches the output of the original implementation

In [18]:
# Baseline result: convert to row-major layout, copy it off the device, then
# turn it into a torch tensor.
output = ttnn.from_device(ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT))
torch_output = ttnn.to_torch(output)

# Optimized result: same three conversion steps.
optimized_output = ttnn.from_device(ttnn.to_layout(optimized_output, ttnn.ROW_MAJOR_LAYOUT))
torch_optimized_output = ttnn.to_torch(optimized_output)

# NOTE(review): torch.allclose is called with its default tolerances even
# though the optimized path uses bfloat8_b for some intermediates - confirm
# this is the intended strictness.
assert torch.allclose(torch_output, torch_optimized_output)

[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::Untilize                             finished in     0.006022145 seconds
[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program of Operation tt::tt_metal::Untilize                             finished in     0.367331572 seconds


## Close the device

In [19]:
# Close the device handle that was opened at the top of the notebook.
ttnn.close(device)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Closing device 0
