# Simple Weighted Attention Mechanism V2

We continue implementing the Simple Weighted Attention Mechanism, which we'll call "V2". This version uses `torch` linear layers, the first step toward building a neural network.

To keep things interesting, we'll also dig into why Tenstorrent hardware exists in the first place. For example, what is the point of a dedicated AI accelerator when we already have GPUs? How much faster is a Wormhole than a very powerful CPU? Since I have a Wormhole n150d, the discussion centers on that card.

It's hard to appreciate what Tenstorrent hardware can do if you only look at performance from one angle. For example, if you're used to GPU acceleration, you may well ask "what's the point?" after seeing the speedups this hardware provides here.

We'll look more closely at all of this, and more, in this project.

As with our other implementations, we'll start with the `torch` version. As in the other notebooks, the code is inspired by Sebastian Raschka's LLM From Scratch code.

## Import Libraries

As before, we start by importing the necessary libraries.

We also set the `torch` manual seed to `789`. Sebastian uses this seed in his projects, so I'll use it too.


In [1]:
import torch
from torch import nn

torch.manual_seed(789)

<torch._C.Generator at 0x733eaa138050>

## SelfAttention_v2 Class

Next, we implement an improved version of the self-attention mechanism from the previous notebook. This time, we generalize it into a Python class called `SelfAttention_v2`.

Usage is simple: it is a `torch` module that performs a forward pass on the given inputs. You construct it with the input and output dimensions, then call the instance on a tensor:

```python
# x is the input tensor, with d_in features per row
sa_v2 = SelfAttention_v2(d_in, d_out)
result = sa_v2(x)
```

In [None]:
class SelfAttention_v2(nn.Module):
  def __init__(self, d_in, d_out):
    super().__init__()
    
    self.W_query = nn.Linear(d_in, d_out, bias=False)
    self.W_key = nn.Linear(d_in, d_out, bias=False)
    self.W_value = nn.Linear(d_in, d_out, bias=False)

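  def forward(self, x):
    # Project the inputs into key, query and value spaces
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    # Unnormalized attention scores: every query dotted with every key
    attn_scores = queries @ keys.T
    # Scale by sqrt(d_k), then normalize each row into weights that sum to 1
    attn_weights = torch.softmax(
      attn_scores / keys.shape[-1] ** 0.5,
      dim=-1
    )
    # Each output row is a weighted sum of the value vectors
    context_vec = attn_weights @ values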

    return context_vec


Let's test this class on the CPU: we initialize the same input we have been using, create a `SelfAttention_v2` instance, and run a forward pass on the input as-is.

In [3]:
context = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

d = context[0].shape
d_in = d[0]
d_out = d[0] - 1

sa_v2 = SelfAttention_v2(d_in, d_out)

context_vec = sa_v2(context)
context_vec

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
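As an extra sanity check (not part of the original workflow), the same computation can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, which by default scales by `1 / sqrt(E)` where `E` is the last dimension of the queries, the same factor used above. This sketch assumes PyTorch 2.0 or later and builds standalone `nn.Linear` layers instead of reusing `sa_v2`, so its numbers need not match the output above. The built-in function is called with a leading batch dimension, which is added and then removed.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(789)

# Same toy input as above: 6 tokens, 3 features each
x = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
d_in, d_out = 3, 2

W_query = nn.Linear(d_in, d_out, bias=False)
W_key = nn.Linear(d_in, d_out, bias=False)
W_value = nn.Linear(d_in, d_out, bias=False)

with torch.no_grad():
    queries, keys, values = W_query(x), W_key(x), W_value(x)

    # Manual computation, mirroring SelfAttention_v2.forward
    attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
    manual = attn_weights @ values

    # Built-in computation: add a batch dimension of size 1, then drop it
    builtin = F.scaled_dot_product_attention(
        queries.unsqueeze(0), keys.unsqueeze(0), values.unsqueeze(0)
    ).squeeze(0)

print(torch.allclose(manual, builtin, atol=1e-5))
```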

## Making it work with Tenstorrent

The tricky part here is figuring out _what_ to port to the Tenstorrent hardware.

With our current skill set (at least mine), I don't know how to do _everything_ from scratch in `ttnn`, but we can at least accelerate part of the compute by offloading some of the tensor calculations inside the `SelfAttention_v2` class.

That makes the `forward` method a great candidate, since that is where those calculations happen:

```python
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
context_vec = attn_weights @ values
```

We can send the keys, queries, and values tensors to the device and perform the multiplications with `ttnn.matmul`. The goal is to offload as much computation as possible to the hardware.

At the same time, we have to minimize data transfers between CPU (host) and device memory, since those transfers can wipe out any benefit of computing on the Tenstorrent hardware.

To find out whether this helps or hurts, we'll build a small benchmark: generate a large number of random matrices and run repeated forward passes, comparing the CPU against the Tenstorrent hardware for this type of computation.

Let's start by opening the device.

In [4]:
import ttnn

device_id = 0
device = ttnn.open_device(device_id=device_id)

2025-04-27 13:13:28.073 | DEBUG    | ttnn:<module>:83 - Initial ttnn.CONFIG:
Config{cache_path=/home/avgdev/.cache/ttnn,model_cache_path=/home/avgdev/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_should_raise_exception=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}
New chip! We now have 1 chips

                 Device | INFO     | Opening user mode device driver
2025-04-27 13:13:28.146 | INFO     | SiliconDriver   - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
2025-04-27 13:13:28.162 | INFO     | SiliconDriver   - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
2025-04-27 13:13:28.164 | INFO     | SiliconDriver   - Harvesting mask for chip 0 is 0x200 (physical layout: 0x1, logical: 0x200, simulated harvesting mask: 0x0).
2025-04-27 13:13:28.165 | INFO     | SiliconDriver   - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
2025-04-27 13:13:28.165 | INFO     | SiliconDriver   - Detected PCI devices: [0]
2025-04-27 13:13:28.165 | INFO     | SiliconDriver   - Using local chip ids: 
Chip initialization complete (found )
Chip initializing complete...
 ARC
 [4/4] DRAM
 [16/16] ETH
 CPU
Chip detection complete (found )


With the device open, let's translate the forward pass of the self-attention mechanism to `ttnn`, line by line.

First line:

```python
attn_scores = queries @ keys.T
```

This involves:
1. Computing the queries by applying `W_query` to the input.
2. Computing the keys with `W_key` and transposing them.
3. Creating `ttnn` tensors from the queries and the transposed keys.
4. Multiplying the queries by the transposed keys on the device with `ttnn.matmul`.


The CPU and device results should be nearly identical; they differ slightly because the device computes in `bfloat16` while the CPU uses `float32`.
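One way to put a number on "nearly identical" is the Pearson correlation coefficient (PCC), the same metric that shows up as `comparison_mode_pcc` in the `ttnn.CONFIG` log above. The sketch below is a CPU-only approximation: it emulates a `bfloat16` matmul by casting to `torch.bfloat16` (the device's actual accumulation behavior may differ) and uses random stand-in `queries` and `keys`. With real device output you would convert it with `ttnn.to_torch` first and compare that instead. The `pcc` helper is hypothetical, not a `ttnn` API.

```python
import torch

def pcc(a: torch.Tensor, b: torch.Tensor) -> float:
    """Hypothetical helper: Pearson correlation between two flattened tensors."""
    stacked = torch.stack([a.flatten().float(), b.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()

torch.manual_seed(789)
queries = torch.randn(6, 2)  # stand-ins for the projected queries
keys = torch.randn(6, 2)     # stand-ins for the projected keys

ref = queries @ keys.T  # float32 reference
# Rough emulation of the device path: do the matmul in bfloat16
approx = (queries.bfloat16() @ keys.T.bfloat16()).float()

max_abs_diff = (ref - approx).abs().max().item()
print(f"PCC: {pcc(ref, approx):.6f}, max abs diff: {max_abs_diff:.4f}")
```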


In [5]:

queries = sa_v2.W_query(context)
keys = sa_v2.W_key(context)
keys_transposed = keys.T


queries_ttnn = ttnn.from_torch(
  queries,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=device
)

keys_transposed_ttnn = ttnn.from_torch(
  keys_transposed,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=device
)

attn_scores_ttnn = ttnn.matmul(
  queries_ttnn,
  keys_transposed_ttnn
)

# For comparison
attn_scores_cpu = queries @ keys.T 
 
attn_scores_ttnn, attn_scores_cpu




(ttnn.Tensor([[ 0.28906,  0.07178,  ...,  0.13379, -0.04980],
              [ 0.46484,  0.17090,  ...,  0.17676,  0.00867],
              ...,
              [ 0.21777,  0.08691,  ...,  0.07812,  0.01489],
              [ 0.34180,  0.12598,  ...,  0.12988,  0.00720]], shape=Shape([6, 6]), dtype=DataType::BFLOAT16, layout=Layout::TILE),
 tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
         [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
         [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
         [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
         [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
         [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
        grad_fn=<MmBackward0>))

Next, it gets trickier: we have to deal with the softmax and the scaling that precedes it. Note that we can't directly divide a `ttnn` tensor by a scalar here, so the scaling has to happen on the CPU.

```python
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
```

Here is what needs to happen:

1. Bring `attn_scores_ttnn` back to CPU memory as a `torch` tensor and divide it there by the scale value, `keys.shape[-1] ** 0.5`.
2. Send the scaled result back to the device.
3. Perform the softmax along the last dimension.

Again, the result should be close to the CPU version.
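Dividing by `keys.shape[-1] ** 0.5` is the same as multiplying by its reciprocal, which is the form the optimized class later in this notebook uses (`attn_scores_ttnn * self._scaler`). A quick CPU check with random stand-in scores, assuming `d_k = 2` as in the toy example, also confirms that each softmax row sums to 1:

```python
import torch

torch.manual_seed(789)
attn_scores = torch.randn(6, 6)  # stand-in for the 6x6 attention scores
d_k = 2                          # keys.shape[-1] in the toy example

# Form used in SelfAttention_v2: divide by sqrt(d_k)
divided = torch.softmax(attn_scores / d_k ** 0.5, dim=-1)

# Equivalent form: precompute 1 / sqrt(d_k) once and multiply
scaler = 1 / (d_k ** 0.5)
multiplied = torch.softmax(attn_scores * scaler, dim=-1)

print(torch.allclose(divided, multiplied))  # True
print(multiplied.sum(dim=-1))               # each row sums to 1
```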

In [6]:
attn_scores_scaled_torch = ttnn.to_torch(attn_scores_ttnn) / keys.shape[-1] ** 0.5
attn_scores_scaled_ttnn = ttnn.from_torch(
  attn_scores_scaled_torch,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=device
)

attn_weights_ttnn = ttnn.softmax(attn_scores_scaled_ttnn, dim=-1)

# CPU version
attn_weights = torch.softmax(attn_scores_scaled_torch, dtype=torch.float32, dim=-1)

attn_weights_ttnn, attn_weights



(ttnn.Tensor([[ 0.19629,  0.16211,  ...,  0.17188,  0.15039],
              [ 0.20898,  0.16504,  ...,  0.16602,  0.14355],
              ...,
              [ 0.18652,  0.16602,  ...,  0.16309,  0.15625],
              [ 0.19922,  0.16602,  ...,  0.16602,  0.14941]], shape=Shape([6, 6]), dtype=DataType::BFLOAT16, layout=Layout::TILE),
 TorchTensor([[0.1920, 0.1647, 0.1652, 0.1550, 0.1721, 0.1511],
              [0.2040, 0.1658, 0.1663, 0.1496, 0.1665, 0.1478],
              [0.2035, 0.1659, 0.1662, 0.1499, 0.1662, 0.1482],
              [0.1867, 0.1667, 0.1668, 0.1571, 0.1661, 0.1565],
              [0.1830, 0.1668, 0.1670, 0.1588, 0.1658, 0.1585],
              [0.1936, 0.1662, 0.1665, 0.1542, 0.1667, 0.1529]]))

Finally, we can compute the context vector. This involves another matmul on the attention weights and values.

```python
context_vec = attn_weights @ values
```
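Each row of `context_vec` is a weighted sum of the value vectors, with the weights taken from the matching row of `attn_weights`. The small CPU check below, using random stand-in tensors shaped like the toy example, spells that out with an explicit loop and compares it to the matmul:

```python
import torch

torch.manual_seed(789)
attn_weights = torch.softmax(torch.randn(6, 6), dim=-1)  # 6 tokens attending to 6 tokens
values = torch.randn(6, 2)                               # one 2-d value vector per token

context_vec = attn_weights @ values

# Row i of the result is sum_j attn_weights[i, j] * values[j]
manual = torch.stack([
    sum(attn_weights[i, j] * values[j] for j in range(values.shape[0]))
    for i in range(attn_weights.shape[0])
])

print(torch.allclose(context_vec, manual, atol=1e-6))
```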

In [7]:
values = sa_v2.W_value(context)

values_ttnn = ttnn.from_torch(
  values,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=device
)

context_vec_ttnn = ttnn.matmul(attn_weights_ttnn, values_ttnn)

context_vec_cpu = attn_weights @ values

context_vec_ttnn, context_vec_cpu

(ttnn.Tensor([[-0.07373,  0.07080],
              [-0.07373,  0.07080],
              ...,
              [-0.07617,  0.06689],
              [-0.07471,  0.06982]], shape=Shape([6, 2]), dtype=DataType::BFLOAT16, layout=Layout::TILE),
 TorchTensor([[-0.0739,  0.0712],
              [-0.0748,  0.0703],
              [-0.0749,  0.0702],
              [-0.0760,  0.0684],
              [-0.0763,  0.0679],
              [-0.0754,  0.0693]], grad_fn=<AliasBackward0>))

Then we close the device.

In [8]:
ttnn.close_device(device)

                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


This is a lot more verbose. Let's put it all together in a revised class, `SelfAttention_ttnn_v2`. The only interface change is an extra `device` parameter, so the module can move tensors to and from the device internally.

In [9]:
import torch
import ttnn
from torch import nn

torch.manual_seed(789)


class SelfAttention_ttnn_v2(nn.Module):
  def __init__(self, d_in, d_out, device):
    super().__init__()
    
    self.W_query = nn.Linear(d_in, d_out, bias=False)
    self.W_key = nn.Linear(d_in, d_out, bias=False)
    self.W_value = nn.Linear(d_in, d_out, bias=False)

    self._device = device

  def forward(self, x):
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    queries_ttnn = ttnn.from_torch(
      queries,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self._device
    )

    keys_transposed_ttnn = ttnn.from_torch(
      keys.T,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self._device
    )

    values_ttnn = ttnn.from_torch(
      values,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self._device
    )

    attn_scores_ttnn = ttnn.matmul(
      queries_ttnn,
      keys_transposed_ttnn
    )

    attn_scores_scaled_torch = ttnn.to_torch(attn_scores_ttnn) / (keys.shape[-1] ** 0.5)
    attn_scores_scaled_ttnn = ttnn.from_torch(
      attn_scores_scaled_torch,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self._device
    )
    attn_weights_ttnn = ttnn.softmax(attn_scores_scaled_ttnn, dim=-1)

    context_vec_ttnn = ttnn.matmul(attn_weights_ttnn, values_ttnn)

    context_vec = ttnn.to_torch(context_vec_ttnn)

    return context_vec

In [10]:
context = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

d = context[0].shape
d_in = d[0]
d_out = d[0] - 1

device_id = 0
device = ttnn.open_device(device_id=device_id)

sa_ttnn_v2 = SelfAttention_ttnn_v2(d_in, d_out, device)
context_vec = sa_ttnn_v2(context)

print(context_vec)

ttnn.close_device(device)



                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
TorchTensor([[-0.0737,  0.0708],
             [-0.0737,  0.0708],
             [-0.0737,  0.0708],
             [-0.0757,  0.0684],
             [-0.0762,  0.0669],
             [-0.0747,  0.0698]], dtype=torch.bfloat16)
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


In [11]:
import time

start = time.process_time_ns()

print("hello")

end = time.process_time_ns()

print(f"Time: {(end - start) / 1000000:.4f} milliseconds")

hello
Time: 0.0992 milliseconds
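Note that `time.process_time_ns()` measures the CPU time consumed by the Python process, not elapsed wall-clock time, so any time spent waiting (for example, on the device) is not counted. The `PerfTimer` class below uses `time.time()`, which is wall-clock time. For timing short intervals, the standard library also offers `time.perf_counter()`, a high-resolution clock intended for measuring durations. The hypothetical context manager below is one alternative built on it; it is not what produced the numbers in this notebook. It also shows `process_time` missing the time spent sleeping.

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label: str, results: dict):
    """Hypothetical helper: store the with-block's wall-clock duration in ms."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = (time.perf_counter() - start) * 1000

timings = {}
cpu_start = time.process_time()
with timed("sleep", timings):
    time.sleep(0.05)  # stands in for waiting on work we didn't do ourselves
cpu_ms = (time.process_time() - cpu_start) * 1000

print(f"wall clock: {timings['sleep']:.2f} ms, process CPU time: {cpu_ms:.2f} ms")
```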


In [12]:
import time

class PerfTimer:
  def __init__(self):
    self.start_time = 0
    self.end_time = 0
    
  def start(self):
    self.start_time = time.time()
    
  def stop(self):
    self.end_time = time.time()

  def reset(self):
    self.start_time = 0
    self.end_time = 0
    
  def elapsed_ms(self):
    return (self.end_time - self.start_time) * 1000

In [13]:
torch.manual_seed(789)

t = PerfTimer()

t.start()
torch_tensors = torch.stack([torch.randn(1024, 2048) for _ in range(0, 10000)])
t.stop()

torch_tensors.shape, t.elapsed_ms()

(torch.Size([10000, 1024, 2048]), 81634.54627990723)
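That cell allocates a lot of memory. A quick back-of-the-envelope check of the stacked tensor's size, assuming the default `float32` dtype (4 bytes per element). Because the list of 10000 tensors is still alive while `torch.stack` copies it into one new tensor, peak usage during that cell is roughly double this figure.

```python
import torch

n_tensors, rows, cols = 10_000, 1024, 2048
bytes_per_element = torch.empty(0, dtype=torch.float32).element_size()  # 4 bytes

total_bytes = n_tensors * rows * cols * bytes_per_element
total_gib = total_bytes / 2**30

print(f"{total_bytes:,} bytes = {total_gib:.3f} GiB")  # 83,886,080,000 bytes = 78.125 GiB
```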

In [14]:
torch.manual_seed(789)
torch.set_printoptions(sci_mode=False)

t.reset()

t.start()
sa_v2 = SelfAttention_v2(2048, 2048)
for tensor in torch_tensors:
  result = sa_v2(tensor)
t.stop()

result, t.elapsed_ms()

(tensor([[-0.0203, -0.0032,  0.0253,  ..., -0.0179, -0.0345,  0.0475],
         [-0.0090,  0.0062,  0.0335,  ..., -0.0040, -0.0133,  0.0238],
         [-0.0138,  0.0054,  0.0356,  ..., -0.0162, -0.0368,  0.0337],
         ...,
         [-0.0289, -0.0172,  0.0434,  ..., -0.0038, -0.0330,  0.0301],
         [-0.0309, -0.0027,  0.0365,  ..., -0.0138, -0.0279,  0.0271],
         [-0.0204,  0.0005,  0.0384,  ..., -0.0225, -0.0309,  0.0429]],
        grad_fn=<MmBackward0>),
 116655.43222427368)

In [85]:
torch.manual_seed(789)
torch.set_printoptions(sci_mode=False)

t.reset()

device_id = 0
device = ttnn.open_device(device_id=device_id)

t.start()
sa_ttnn_v2 = SelfAttention_ttnn_v2(2048, 2048, device)
for tensor in torch_tensors:
  result = sa_ttnn_v2(tensor)

t.stop()

ttnn.close_device(device)

result, t.elapsed_ms()


                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


(TorchTensor([[    -0.0258,      0.0009,     -0.0001,  ...,     -0.0284,
                    0.0184,     -0.0339],
              [    -0.0352,      0.0064,      0.0037,  ...,     -0.0195,
                    0.0128,     -0.0422],
              [    -0.0270,      0.0104,     -0.0033,  ...,     -0.0183,
                    0.0161,     -0.0437],
              ...,
              [    -0.0289,      0.0104,      0.0033,  ...,     -0.0242,
                    0.0034,     -0.0549],
              [    -0.0153,     -0.0067,     -0.0028,  ...,     -0.0292,
                    0.0149,     -0.0310],
              [    -0.0332,      0.0133,      0.0093,  ...,     -0.0297,
                    0.0131,     -0.0275]], dtype=torch.bfloat16),
 4782.804250717163)

In [15]:
import torch
import ttnn
from torch import nn

torch.manual_seed(789)

class SelfAttention_ttnn_opt_v2(nn.Module):
  def __init__(self, d_in, d_out, device):
    super().__init__()

    self.W_query = nn.Linear(d_in, d_out, bias=False)
    self.W_key = nn.Linear(d_in, d_out, bias=False)
    self.W_value = nn.Linear(d_in, d_out, bias=False)

    self._device = device
    
    # Extract weight matrices from PyTorch layers and convert to TTNN once
    self.W_query_ttnn = ttnn.from_torch(
      self.W_query.weight, 
      dtype=ttnn.bfloat16, 
      layout=ttnn.TILE_LAYOUT, 
      device=self._device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    
    self.W_key_ttnn = ttnn.from_torch(
      self.W_key.weight, 
      dtype=ttnn.bfloat16, 
      layout=ttnn.TILE_LAYOUT, 
      device=self._device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    
    self.W_value_ttnn = ttnn.from_torch(
      self.W_value.weight, 
      dtype=ttnn.bfloat16, 
      layout=ttnn.TILE_LAYOUT, 
      device=self._device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )

    self._scaler = 1 / (d_out ** 0.5)

  def forward(self, x):
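    # Upload the input once; the projections and attention math stay on the device
    x_ttnn = ttnn.from_torch(
      x,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self._device,  # the instance's device, not the global `device`
    )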
    queries_ttnn = ttnn.linear(
      x_ttnn,
      self.W_query_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=32, x=32)
    )
    values_ttnn = ttnn.linear(
      x_ttnn,
      self.W_value_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=32, x=32)
    )
    keys_ttnn = ttnn.linear(
      x_ttnn,
      self.W_key_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=32, x=32)
    )

    attn_scores_ttnn = ttnn.matmul(
      queries_ttnn, 
      ttnn.permute(keys_ttnn, (1, 0)),
      core_grid=ttnn.CoreGrid(y=32, x=32)
    )

    attn_weights_ttnn = ttnn.softmax(
      attn_scores_ttnn * self._scaler,
      dim=-1
    )

    context_vec_ttnn = ttnn.matmul(
      attn_weights_ttnn,
      values_ttnn,
      core_grid=ttnn.CoreGrid(y=32, x=32)
    )

    context_vec = ttnn.to_torch(context_vec_ttnn)

    return context_vec
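The main difference between `SelfAttention_ttnn_v2` and `SelfAttention_ttnn_opt_v2` is how much data crosses between host and device on every forward pass. Below is a rough count for the benchmark shapes (1024 × 2048 inputs, `d_out = 2048`). It assumes each transfer moves `bfloat16` data (2 bytes per element) with no tile padding, which holds here because 1024 and 2048 are multiples of 32. The one-time weight upload in `__init__` is excluded.

```python
BF16_BYTES = 2
MiB = 2**20

def nbytes(rows: int, cols: int) -> int:
    """Size of a bfloat16 matrix in bytes."""
    return rows * cols * BF16_BYTES

seq, d = 1024, 2048  # rows per input, projection width

# SelfAttention_ttnn_v2.forward: queries, keys.T, values and the rescaled
# scores go up; the raw scores and the context vectors come back down
v2_up = nbytes(seq, d) + nbytes(d, seq) + nbytes(seq, d) + nbytes(seq, seq)
v2_down = nbytes(seq, seq) + nbytes(seq, d)

# SelfAttention_ttnn_opt_v2.forward: only x goes up and the context comes back
opt_up = nbytes(seq, d)
opt_down = nbytes(seq, d)

print(f"v2:  {(v2_up + v2_down) / MiB:.0f} MiB per forward pass")  # 20 MiB
print(f"opt: {(opt_up + opt_down) / MiB:.0f} MiB per forward pass")  # 8 MiB
```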

In [16]:
torch.manual_seed(789)
torch.set_printoptions(sci_mode=False)

t.reset()

device_id = 0
device = ttnn.open_device(device_id=device_id)

ttnn.enable_program_cache(device)

t.start()
sa_ttnn_v2 = SelfAttention_ttnn_opt_v2(2048, 2048, device)
for tensor in torch_tensors:
  result = sa_ttnn_v2(tensor)
  # Disabled experiment: collect results and convert them in groups of 1000
  # (requires `batch = []` before the loop)
  # batch.append(result)
  # if len(batch) == 1000:
  #   for i in range(0, len(batch)):
  #     result = ttnn.to_torch(batch.pop(0))
  #   batch = []

t.stop()

# Disabled: convert any leftover batched results
# for i in range(0, len(batch)):
#   result = ttnn.to_torch(batch.pop(0))

ttnn.close_device(device)

result, t.elapsed_ms()

                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
                  Metal | INFO     | Enabling program cache on device 0
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


(TorchTensor([[-0.0183, -0.0034,  0.0266,  ..., -0.0162, -0.0320,  0.0466],
              [-0.0081,  0.0055,  0.0337,  ..., -0.0035, -0.0113,  0.0249],
              [-0.0121,  0.0045,  0.0354,  ..., -0.0140, -0.0332,  0.0342],
              ...,
              [-0.0271, -0.0160,  0.0430,  ..., -0.0034, -0.0315,  0.0309],
              [-0.0278, -0.0022,  0.0369,  ..., -0.0125, -0.0265,  0.0276],
              [-0.0188,  0.0009,  0.0381,  ..., -0.0210, -0.0293,  0.0420]],
             dtype=torch.bfloat16),
 110242.26951599121)


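**Performance (ms)**

- CPU: Intel Xeon w5-3535X (20 cores / 40 threads)
- Tensix: Wormhole n150d

| Forward passes | CPU (`SelfAttention_v2`) | Tensix (`SelfAttention_ttnn_v2`) | Tensix Fast (`SelfAttention_ttnn_opt_v2`) |
|----------------|--------------------------|----------------------------------|-------------------------------------------|
| 1              | 67.33                    | 105.28                           | 98.39                                     |
| 100            | 1258.74                  | 4676.81                          | 1183.98                                   |
| 1000           | 11671.33                 | 46124.67                         | 10935.41                                  |
| 10000          | 115895.56                | 467358.59                        | 109907.44                                 |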
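To put these timings in perspective, it helps to count the work in one forward pass at these shapes. The sketch below counts an (m × k) by (k × n) matrix multiply as 2·m·k·n floating-point operations, ignores the softmax and scaling (small by comparison), and estimates the sustained throughput implied by the 10000-pass row. These are derived estimates, not measurements. Note that in the `Tensix` column (which appears to correspond to `SelfAttention_ttnn_v2`), the three projections still run on the CPU.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """Multiply-adds counted as 2 FLOPs for an (m x k) @ (k x n) product."""
    return 2 * m * k * n

seq, d_in, d_out = 1024, 2048, 2048

flops_per_pass = (
    3 * matmul_flops(seq, d_in, d_out)  # query, key and value projections
    + matmul_flops(seq, d_out, seq)     # queries @ keys.T
    + matmul_flops(seq, seq, d_out)     # attn_weights @ values
)

passes = 10_000
elapsed_ms = {"CPU": 115895.56, "Tensix": 467358.59, "Tensix Fast": 109907.44}

throughput = {}
for name, ms in elapsed_ms.items():
    throughput[name] = flops_per_pass * passes / (ms / 1000) / 1e12  # TFLOP/s
    print(f"{name:12s} ~{throughput[name]:.2f} TFLOP/s")
```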


**Power Consumption**

|                                      | CPU       | Tensix     | Creating Tensors |
|--------------------------------------|-----------|------------|------------------|
| Baseline power (W)                   | 446       | 425        | 444              |
| Active power (W)                     | 740       | 535        | 537              |
| Elapsed time for 10000 tensors (ms)  | 116655.43 | 110242.27  | 81634.55         |
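Power alone doesn't capture efficiency; energy (power × time) does. Below is a rough estimate from the table above. It assumes the active draw held steady for the whole 10000-tensor run and reports both the energy at the active draw and the energy above each configuration's idle baseline.

```python
runs = {
    # name: (baseline W, active W, elapsed ms for 10000 tensors), from the table
    "CPU": (446, 740, 116655.43),
    "Tensix": (425, 535, 110242.27),
}

energy = {}
for name, (baseline_w, active_w, ms) in runs.items():
    seconds = ms / 1000
    energy[name] = {
        "total_kj": active_w * seconds / 1000,                      # at the active draw
        "above_idle_kj": (active_w - baseline_w) * seconds / 1000,  # extra for the workload
    }
    print(f"{name:7s} total ~{energy[name]['total_kj']:.1f} kJ, "
          f"above idle ~{energy[name]['above_idle_kj']:.1f} kJ")
```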

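Wow, the CPU draws a lot more power, and this is where the point becomes clear: for this workload, the value of the Wormhole is not that it is much faster, but that it delivers similar (slightly better) performance with far better power efficiency. Above each setup's idle baseline, the CPU run adds about 294 W, while the Tensix run adds about 110 W.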
