In [2]:
import torch
#import ttnn
from torch import nn

torch.manual_seed(123)

<torch._C.Generator at 0x7e43787db3b0>

In [3]:
context = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

batch = torch.stack((context, context), dim=0)
d = context.shape
d_in = context.shape[1]
d_out = context.shape[1] - 1

d_in, d_out

(3, 2)

In [4]:

class CausalAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()

    self.d_out = d_out
    
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    
    self.dropout = nn.Dropout(dropout)

    self.mask = torch.triu(
      torch.ones(context_length, context_length),
      diagonal=1
    )

  def forward(self, x): 
    b, num_tokens, d_in = x.shape
    keys = self.W_query(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    # transpose the last 2 dimensions while leaving the batch dimension alone.
    attn_scores = queries @ keys.transpose(1, 2)
    attn_scores.masked_fill(
      self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
    )
    attn_weights = torch.softmax(
      attn_scores / keys.shape[-1]**0.5, dim=-1
    )

    attn_weights = self.dropout(attn_weights)

    context_vec = attn_weights @ values

    return context_vec


In [5]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
ca(batch)


tensor([[[-0.5287, -0.0976],
         [-0.5293, -0.1053],
         [-0.5293, -0.1052],
         [-0.5287, -0.1072],
         [-0.5287, -0.1038],
         [-0.5288, -0.1080]],

        [[-0.5287, -0.0976],
         [-0.5293, -0.1053],
         [-0.5293, -0.1052],
         [-0.5287, -0.1072],
         [-0.5287, -0.1038],
         [-0.5288, -0.1080]]], grad_fn=<UnsafeViewBackward0>)

In [6]:
class MultiHeadAttentionWrapper(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()

    self.heads = nn.ModuleList(
      [
        CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) 
        for _ in range(num_heads)
      ]
    )

  def forward(self, x):
    return torch.cat([head(x) for head in self.heads], dim=-1)



In [7]:
torch.manual_seed(123)
context_length = batch.shape[1]
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
context_vecs, context_vecs.shape

(tensor([[[-0.5287, -0.0976,  0.5122,  0.3448],
          [-0.5293, -0.1053,  0.5123,  0.3449],
          [-0.5293, -0.1052,  0.5121,  0.3448],
          [-0.5287, -0.1072,  0.5096,  0.3438],
          [-0.5287, -0.1038,  0.5078,  0.3427],
          [-0.5288, -0.1080,  0.5113,  0.3446]],
 
         [[-0.5287, -0.0976,  0.5122,  0.3448],
          [-0.5293, -0.1053,  0.5123,  0.3449],
          [-0.5293, -0.1052,  0.5121,  0.3448],
          [-0.5287, -0.1072,  0.5096,  0.3438],
          [-0.5287, -0.1038,  0.5078,  0.3427],
          [-0.5288, -0.1080,  0.5113,  0.3446]]], grad_fn=<CatBackward0>),
 torch.Size([2, 6, 4]))

In [8]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads # Reduce the projection dimension to match the output dim

    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = nn.Dropout(dropout)

    self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

  def forward(self, x):
    b, num_tokens, d_in = x.shape

    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
    values = values.view(b, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
    
    keys = keys.transpose(1, 2)
    queries = queries.transpose(1, 2)
    values = values.transpose(1, 2)
    
    attn_scores = queries @ keys.transpose(2, 3)

    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

    attn_scores.masked_fill_(mask_bool, -torch.inf)

    print("attn_scores", attn_scores)

    attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)

    context_vec = torch.transpose((attn_weights @ values), 1, 2)

    context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

    return context_vec

In [9]:
torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

context_vecs, context_vecs.shape


attn_scores tensor([[[[ 0.2029,    -inf,    -inf,    -inf,    -inf,    -inf],
          [ 0.1734,  0.2631,    -inf,    -inf,    -inf,    -inf],
          [ 0.1730,  0.2625,  0.2601,    -inf,    -inf,    -inf],
          [ 0.0777,  0.1179,  0.1168,  0.0648,    -inf,    -inf],
          [ 0.1178,  0.1787,  0.1770,  0.0983,  0.0973,    -inf],
          [ 0.0885,  0.1343,  0.1330,  0.0738,  0.0731,  0.0908]],

         [[ 0.1081,    -inf,    -inf,    -inf,    -inf,    -inf],
          [-0.0079, -0.0029,    -inf,    -inf,    -inf,    -inf],
          [-0.0063, -0.0023, -0.0025,    -inf,    -inf,    -inf],
          [-0.0267, -0.0099, -0.0104, -0.0005,    -inf,    -inf],
          [ 0.0237,  0.0088,  0.0092,  0.0004,  0.0148,    -inf],
          [-0.0409, -0.0151, -0.0159, -0.0008, -0.0254,  0.0058]]],


        [[[ 0.2029,    -inf,    -inf,    -inf,    -inf,    -inf],
          [ 0.1734,  0.2631,    -inf,    -inf,    -inf,    -inf],
          [ 0.1730,  0.2625,  0.2601,    -inf,    -inf,   

(tensor([[[-0.4519,  0.2216],
          [-0.5889,  0.0122],
          [-0.6313, -0.0576],
          [-0.5685, -0.0832],
          [-0.5541, -0.0964],
          [-0.5311, -0.1077]],
 
         [[-0.4519,  0.2216],
          [-0.5889,  0.0122],
          [-0.6313, -0.0576],
          [-0.5685, -0.0832],
          [-0.5541, -0.0964],
          [-0.5311, -0.1077]]], grad_fn=<ViewBackward0>),
 torch.Size([2, 6, 2]))

## TTNN Example

In [10]:
import ttnn

torch.manual_seed(123)

2025-05-04 13:54:01.144 | DEBUG    | ttnn.library_tweaks:prepare_dir_as_metal_home:54 - Existing installation of 0.57.0rc60+any detected
2025-05-04 13:54:01.168 | DEBUG    | ttnn:<module>:83 - Initial ttnn.CONFIG:
Config{cache_path=/home/avgdev/.cache/ttnn,model_cache_path=/home/avgdev/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_should_raise_exception=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


<torch._C.Generator at 0x7e43787db3b0>

In [11]:
core_grid_y = 8
core_grid_x = 8
MINUS_INFINITY=-1e9

class MultiHeadAttention_ttnn(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, device, qkv_bias=False):
    super().__init__()

    self.device = device
    self.dropout_prob = dropout

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads # Reduce the projection dimension to match the output dim

    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = nn.Dropout(dropout)
    self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

    self.W_query_ttnn = ttnn.from_torch(
      self.W_query.weight,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    self.W_key_ttnn = ttnn.from_torch(
      self.W_key.weight,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    self.W_value_ttnn = ttnn.from_torch(
      self.W_value.weight,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    self.mask_ttnn = ttnn.from_torch(
      self.mask,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )

  def forward(self, x):
    b, num_tokens, d_in = x.shape

    x_ttnn = ttnn.from_torch(
      x,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device
    )

    keys_ttnn = ttnn.linear(
      x_ttnn,
      self.W_key_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    queries_ttnn = ttnn.linear(
      x_ttnn,
      self.W_query_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    values_ttnn = ttnn.linear(
      x_ttnn,
      self.W_value_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    
    keys_ttnn = ttnn.reshape(keys_ttnn, (b, num_tokens, self.num_heads, self.head_dim))
    values_ttnn = ttnn.reshape(values_ttnn, (b, num_tokens, self.num_heads, self.head_dim))
    queries_ttnn = ttnn.reshape(queries_ttnn, (b, num_tokens, self.num_heads, self.head_dim))

    keys_ttnn = ttnn.permute(keys_ttnn, (0, 2, 1, 3))
    values_ttnn = ttnn.permute(values_ttnn, (0, 2, 1, 3))
    queries_ttnn = ttnn.permute(queries_ttnn, (0, 2, 1, 3))

    attn_scores_ttnn = ttnn.matmul(
      queries_ttnn, 
      ttnn.permute(keys_ttnn, (0, 1, 3, 2)),
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )


    # if self.dropout_prob > 0.0:
    inf_ttnn = ttnn.full_like(
      attn_scores_ttnn,
      MINUS_INFINITY, 
      layout=ttnn.TILE_LAYOUT
    )

    #attn_scores_ttnn = ttnn.where(self.mask_ttnn[:num_tokens, :num_tokens], inf_ttnn, attn_scores_ttnn)
    attn_scores_ttnn = attn_scores_ttnn * (1 / (self.head_dim ** 0.5))
    attn_mask_ttnn = ttnn.from_torch(
      torch.randn(b, 1, 1, num_tokens),
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self.device,
    )
    attn_scores_ttnn += attn_mask_ttnn

    print("attn_scores_ttnn", ttnn.to_layout(attn_scores_ttnn, ttnn.ROW_MAJOR_LAYOUT))

    attn_weights_ttnn = ttnn.softmax(attn_scores_ttnn, dim=-1)

    if self.dropout_prob > 0.0:
      attn_weights_ttnn = ttnn.experimental.dropout(
        attn_weights_ttnn,
        seed=123,
        probability=self.dropout_prob,
        scale=1.0 / (1.0 - self.dropout_prob)
      )

    
    context_vec_ttnn = ttnn.matmul(
      attn_weights_ttnn,
      values_ttnn,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    ) 

    context_vec_ttnn = ttnn.permute(context_vec_ttnn, (0, 2, 1, 3))

    
    context_vec_ttnn = ttnn.reshape(context_vec_ttnn, (b, num_tokens, self.d_out))

    context_vec = ttnn.from_device(context_vec_ttnn)
    context_vec = ttnn.to_torch(context_vec)

    return context_vec

In [12]:

torch.manual_seed(123)

device_id = 0
device = ttnn.open_device(device_id=device_id)

batch_size, context_length, d_in = batch.shape
d_out = 32
mha = MultiHeadAttention_ttnn(d_in, d_out, context_length, 0.0, num_heads=2, device=device)

context_vecs = mha(batch)

ttnn.close_device(device)

context_vecs, context_vecs.shape

                 Device | INFO     | Opening user mode device driver
[32m2025-05-04 13:54:02.969[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled

[32m2025-05-04 13:54:02.977[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
[32m2025-05-04 13:54:02.978[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Harvesting mask for chip 0 is 0x200 (physical layout: 0x1, logical: 0x200, simulated harvesting mask: 0x0).
[32m2025-05-04 13:54:02.979[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
[32m2025-05-04 13:54:02.979[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected PCI devices: [0]
[32m2025-05-04 13:54:02.979[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using local chip ids: 

New chip! We now have 1 chips
Chip initialization complete (found )
Chip initializing complete...
 ARC

 [4/4] DRAM

 [16/16] ETH

 CPU

Chip detection complete (found )



attn_scores_ttnn ttnn.Tensor([[[[ 0.11279, -0.62109,  ..., -1.46094,  0.98047],
               [ 0.15625, -0.57812,  ..., -1.41406,  0.99219],
               ...,
               [ 0.12256, -0.61719,  ..., -1.44531,  0.97266],
               [ 0.16895, -0.62109,  ..., -1.41406,  0.94531]],

              [[-0.04395, -0.96484,  ..., -1.46094,  0.62109],
               [-0.13574, -1.07031,  ..., -1.54688,  0.57031],
               ...,
               [-0.06738, -0.98438,  ..., -1.54688,  0.66016],
               [-0.01367, -0.89062,  ..., -1.50000,  0.72656]]],

             [[[ 0.60156, -0.89453,  ..., -0.50391,  0.22852],
               [ 0.64453, -0.85156,  ..., -0.45898,  0.24121],
               ...,
               [ 0.61328, -0.89062,  ..., -0.49219,  0.22168],
               [ 0.66016, -0.89453,  ..., -0.46289,  0.19238]],

              [[ 0.44531, -1.23438,  ..., -0.50391, -0.13281],
               [ 0.35352, -1.34375,  ..., -0.59375, -0.18359],
               ...,
             

(TorchTensor([[[ 0.1079, -0.1455,  0.1182,  0.0165, -0.1001,  0.1318,  0.0457,
                -0.0620, -0.1128, -0.4609,  0.1270, -0.5430, -0.0072, -0.2031,
                -0.1270, -0.1729, -0.1816,  0.5039, -0.0374,  0.1348, -0.1641,
                -0.1504,  0.2734,  0.0986,  0.3691,  0.1768,  0.2080, -0.3848,
                -0.3359, -0.1162, -0.0500,  0.1504],
               [ 0.1084, -0.1494,  0.1206,  0.0160, -0.1030,  0.1348,  0.0461,
                -0.0645, -0.1123, -0.4668,  0.1309, -0.5547, -0.0066, -0.2041,
                -0.1270, -0.1738, -0.1826,  0.5039, -0.0364,  0.1338, -0.1660,
                -0.1504,  0.2715,  0.0986,  0.3672,  0.1768,  0.2090, -0.3848,
                -0.3359, -0.1147, -0.0515,  0.1494],
               [ 0.1074, -0.1475,  0.1196,  0.0159, -0.1021,  0.1328,  0.0457,
                -0.0640, -0.1113, -0.4609,  0.1289, -0.5469, -0.0065, -0.2021,
                -0.1260, -0.1719, -0.1836,  0.5039, -0.0361,  0.1338, -0.1670,
                -0.1514, 

In [13]:
ttnn.close_device(device)

                 Always | FATAL    | Attempting to push work to Device 0 which is not initialized. Ignoring...
