# Multi-Head Attention
The last part of attention. This is going to be a lot of code. Brace yourselves.

![Brace yourself](./img/brace.jpg)

This is going to be t 

In [3]:
import torch
from torch import nn

torch.manual_seed(123)

<torch._C.Generator at 0x7443e0291030>

In [4]:
context = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

batch = torch.stack((context, context), dim=0)
d = context.shape
d_in = context.shape[1]
d_out = context.shape[1] - 1

d_in, d_out

(3, 2)

In [5]:

class CausalAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()

    self.d_out = d_out
    
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    
    self.dropout = nn.Dropout(dropout)

    self.mask = torch.triu(
      torch.ones(context_length, context_length),
      diagonal=1
    )

  def forward(self, x): 
    b, num_tokens, d_in = x.shape
    keys = self.W_query(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    # transpose the last 2 dimensions while leaving the batch dimension alone.
    attn_scores = queries @ keys.transpose(1, 2)
    attn_scores.masked_fill(
      self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
    )
    attn_weights = torch.softmax(
      attn_scores / keys.shape[-1]**0.5, dim=-1
    )

    attn_weights = self.dropout(attn_weights)

    context_vec = attn_weights @ values

    return context_vec


In [6]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
ca(batch)


tensor([[[-0.5287, -0.0976],
         [-0.5293, -0.1053],
         [-0.5293, -0.1052],
         [-0.5287, -0.1072],
         [-0.5287, -0.1038],
         [-0.5288, -0.1080]],

        [[-0.5287, -0.0976],
         [-0.5293, -0.1053],
         [-0.5293, -0.1052],
         [-0.5287, -0.1072],
         [-0.5287, -0.1038],
         [-0.5288, -0.1080]]], grad_fn=<UnsafeViewBackward0>)

In [7]:
class MultiHeadAttentionWrapper(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()

    self.heads = nn.ModuleList(
      [
        CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) 
        for _ in range(num_heads)
      ]
    )

  def forward(self, x):
    return torch.cat([head(x) for head in self.heads], dim=-1)



In [8]:
torch.manual_seed(123)
context_length = batch.shape[1]
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
context_vecs, context_vecs.shape

(tensor([[[-0.5287, -0.0976,  0.5122,  0.3448],
          [-0.5293, -0.1053,  0.5123,  0.3449],
          [-0.5293, -0.1052,  0.5121,  0.3448],
          [-0.5287, -0.1072,  0.5096,  0.3438],
          [-0.5287, -0.1038,  0.5078,  0.3427],
          [-0.5288, -0.1080,  0.5113,  0.3446]],
 
         [[-0.5287, -0.0976,  0.5122,  0.3448],
          [-0.5293, -0.1053,  0.5123,  0.3449],
          [-0.5293, -0.1052,  0.5121,  0.3448],
          [-0.5287, -0.1072,  0.5096,  0.3438],
          [-0.5287, -0.1038,  0.5078,  0.3427],
          [-0.5288, -0.1080,  0.5113,  0.3446]]], grad_fn=<CatBackward0>),
 torch.Size([2, 6, 4]))

In [9]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads # Reduce the projection dimension to match the output dim

    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = nn.Dropout(dropout)

    self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

  def forward(self, x):
    b, num_tokens, d_in = x.shape

    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
    
    keys = keys.transpose(1, 2)
    queries = queries.transpose(1, 2)
    values = values.transpose(1, 2)
    
    attn_scores = queries @ keys.transpose(2, 3)

    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

    attn_scores.masked_fill_(mask_bool, -torch.inf)

    print("attn_scores", attn_scores)

    attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)

    context_vec = torch.transpose((attn_weights @ values), 1, 2)

    context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

    return context_vec

In [17]:
torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 4
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

context_vecs, context_vecs.shape


attn_scores tensor([[[[ 0.2477,    -inf,    -inf,    -inf,    -inf,    -inf],
          [ 0.1301,  0.2214,    -inf,    -inf,    -inf,    -inf],
          [ 0.1311,  0.2198,  0.2194,    -inf,    -inf,    -inf],
          [ 0.0395,  0.1158,  0.1157,  0.0668,    -inf,    -inf],
          [ 0.1120,  0.1295,  0.1291,  0.0646,  0.0858,    -inf],
          [ 0.0365,  0.1395,  0.1394,  0.0821,  0.0989,  0.0875]],

         [[-0.0366,    -inf,    -inf,    -inf,    -inf,    -inf],
          [-0.2326, -0.2530,    -inf,    -inf,    -inf,    -inf],
          [-0.2270, -0.2472, -0.2359,    -inf,    -inf,    -inf],
          [-0.1564, -0.1675, -0.1606, -0.1001,    -inf,    -inf],
          [-0.0618, -0.0731, -0.0682, -0.0500,  0.0386,    -inf],
          [-0.2178, -0.2314, -0.2223, -0.1366,  0.0035, -0.2325]]],


        [[[ 0.2477,    -inf,    -inf,    -inf,    -inf,    -inf],
          [ 0.1301,  0.2214,    -inf,    -inf,    -inf,    -inf],
          [ 0.1311,  0.2198,  0.2194,    -inf,    -inf,   

(tensor([[[-0.3132, -0.2272,  0.4772,  0.1063],
          [-0.2308,  0.0329,  0.5764,  0.3007],
          [-0.2059,  0.1190,  0.6097,  0.3654],
          [-0.1642,  0.1340,  0.5431,  0.3503],
          [-0.1689,  0.1794,  0.5296,  0.3389],
          [-0.1407,  0.1699,  0.5040,  0.3403]],
 
         [[-0.3132, -0.2272,  0.4772,  0.1063],
          [-0.2308,  0.0329,  0.5764,  0.3007],
          [-0.2059,  0.1190,  0.6097,  0.3654],
          [-0.1642,  0.1340,  0.5431,  0.3503],
          [-0.1689,  0.1794,  0.5296,  0.3389],
          [-0.1407,  0.1699,  0.5040,  0.3403]]], grad_fn=<ViewBackward0>),
 torch.Size([2, 6, 4]))

## TTNN Example

In [11]:
import ttnn

torch.manual_seed(123)

2025-05-04 20:53:37.307 | DEBUG    | ttnn.library_tweaks:prepare_dir_as_metal_home:54 - Existing installation of 0.57.0rc60+any detected


2025-05-04 20:53:37.331 | DEBUG    | ttnn:<module>:83 - Initial ttnn.CONFIG:
Config{cache_path=/home/avgdev/.cache/ttnn,model_cache_path=/home/avgdev/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_should_raise_exception=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


<torch._C.Generator at 0x7443e0291030>

In [None]:
core_grid_y = 8
core_grid_x = 8
MINUS_INFINITY=-1e9

#ttnn.set_printoptions(profile="short")

class MultiHeadAttention_ttnn(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, device, qkv_bias=False):
    super().__init__()

    self.device = device
    self.dropout_prob = dropout

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads # Reduce the projection dimension to match the output dim

    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = nn.Dropout(dropout)
    self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

    self.W_query_ttnn = ttnn.from_torch(
      self.W_query.weight,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    self.W_key_ttnn = ttnn.from_torch(
      self.W_key.weight,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    self.W_value_ttnn = ttnn.from_torch(
      self.W_value.weight,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    self.mask_ttnn = ttnn.from_torch(
      self.mask,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )

  def forward(self, x):
    b, num_tokens, d_in = x.shape

    x_ttnn = ttnn.from_torch(
      x,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device
    )

    keys_ttnn = ttnn.linear(
      x_ttnn,
      self.W_key_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    queries_ttnn = ttnn.linear(
      x_ttnn,
      self.W_query_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    values_ttnn = ttnn.linear(
      x_ttnn,
      self.W_value_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    
    keys_ttnn = ttnn.reshape(keys_ttnn, (b, num_tokens, self.num_heads, self.head_dim))
    values_ttnn = ttnn.reshape(values_ttnn, (b, num_tokens, self.num_heads, self.head_dim))
    queries_ttnn = ttnn.reshape(queries_ttnn, (b, num_tokens, self.num_heads, self.head_dim))

    keys_transposed_ttnn = ttnn.permute(keys_ttnn, (0, 2, 3, 1))
    values_ttnn = ttnn.permute(values_ttnn, (0, 2, 1, 3))
    queries_ttnn = ttnn.permute(queries_ttnn, (0, 2, 1, 3))

    attn_scores_ttnn = ttnn.matmul(
      queries_ttnn, 
      keys_transposed_ttnn,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    attn_scores_ttnn = attn_scores_ttnn * (1 / (self.head_dim ** 0.5))

    attn_mask_ttnn = self.mask_ttnn[:num_tokens, :num_tokens] * MINUS_INFINITY
    attn_mask_ttnn = ttnn.reshape(attn_mask_ttnn, (1, 1, num_tokens, num_tokens))
    attn_scores_ttnn += attn_mask_ttnn
    
    attn_weights_ttnn = ttnn.softmax(attn_scores_ttnn, dim=-1)

    if self.dropout_prob > 0.0:
      attn_weights_ttnn = ttnn.experimental.dropout(
        attn_weights_ttnn,
        seed=123,
        probability=self.dropout_prob,
        scale=1.0 / (1.0 - self.dropout_prob)
      )

    
    context_vec_ttnn = ttnn.matmul(
      attn_weights_ttnn,
      values_ttnn,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    ) 

    context_vec_ttnn = ttnn.permute(context_vec_ttnn, (0, 2, 1, 3))

    
    context_vec_ttnn = ttnn.reshape(context_vec_ttnn, (b, num_tokens, self.d_out))

    context_vec = ttnn.from_device(context_vec_ttnn)
    context_vec = ttnn.to_torch(context_vec)

    return context_vec

In [18]:

torch.manual_seed(123)

device_id = 0
device = ttnn.open_device(device_id=device_id)

batch_size, context_length, d_in = batch.shape
d_out =4
mha = MultiHeadAttention_ttnn(d_in, d_out, context_length, 0.0, num_heads=2, device=device)

context_vecs = mha(batch)

ttnn.close_device(device)

context_vecs, context_vecs.shape

                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


(TorchTensor([[[-0.2969, -0.2266,  0.4570,  0.0991],
               [-0.2139,  0.0289,  0.5586,  0.2910],
               [-0.1914,  0.1099,  0.5898,  0.3535],
               [-0.1533,  0.1245,  0.5156,  0.3340],
               [-0.1572,  0.1660,  0.5117,  0.3281],
               [-0.1309,  0.1582,  0.4824,  0.3301]],
 
              [[-0.2969, -0.2266,  0.4570,  0.0991],
               [-0.2139,  0.0289,  0.5586,  0.2910],
               [-0.1914,  0.1099,  0.5898,  0.3535],
               [-0.1533,  0.1245,  0.5156,  0.3340],
               [-0.1572,  0.1660,  0.5117,  0.3281],
               [-0.1309,  0.1582,  0.4824,  0.3301]]], dtype=torch.bfloat16),
 torch.Size([2, 6, 4]))

In [39]:
ttnn.close_device(device)

                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0
