# Simple Weighted Attention Mechanism V2

We continue implementing the Simple Weighted Attention Mechanism and call this version "V2". It uses `torch` linear layers, which is a first step toward building an actual neural network.

**Tenstorrent Hardware: Beyond GPUs**

While exploring an improved attention mechanism, we'll also take the opportunity to dive deeper into the advantages and disadvantages of Tenstorrent hardware. At some point while working through this notebook, you'll probably ask "Why use a TPU when GPUs are already powerful?", or maybe "How does a Wormhole n150d compare to high-performance CPUs?" You'll find the discussions on these topics pretty interesting if you've only been exposed to GPU compute.

The true capabilities of specialized hardware like Tenstorrent's become apparent when you start to compare performance across different architectures. For those familiar with GPU acceleration, the performance benefits of TPUs like the Tensix processor in the Wormhole might seem incremental, or even regressive, but the differences become clear once we start to measure from different perspectives.

As with the previous implementations of LLM-related techniques, we will start by implementing a version using just `torch` to build understanding. The code, as in some previous notebooks, is inspired by Sebastian Raschka's LLM From Scratch code.

## Import Libraries

As before, we start by importing the necessary libraries.

We also set the manual seed for `torch` to be `789`. Sebastian chooses this as the seed in his projects, so I will just use that.


In [2]:
import torch
from torch import nn

torch.manual_seed(789)

<torch._C.Generator at 0x7775b3fd49d0>

## SelfAttention_v2 Class

Next, we will implement an improved version of the self-attention mechanism from the previous notebook. This time, we will generalize the implementation into a Python class called `SelfAttention_v2`.

Usage is simple: it is a `torch` module that performs a forward pass on the given input. For example:

```python
# x is a tensor to be treated as input, with shape (num_tokens, d_in)
sa = SelfAttention_v2(d_in, d_out)
result = sa(x)
```

In [3]:
class SelfAttention_v2(nn.Module):
  def __init__(self, d_in, d_out):
    super().__init__()
    
    self.W_query = nn.Linear(d_in, d_out, bias=False)
    self.W_key = nn.Linear(d_in, d_out, bias=False)
    self.W_value = nn.Linear(d_in, d_out, bias=False)

  def forward(self, x):
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    attn_scores = queries @ keys.T
    attn_weights = torch.softmax(
      attn_scores / keys.shape[-1] ** 0.5,
      dim=-1
    )
    context_vec = attn_weights @ values

    return context_vec
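
To make the arithmetic inside `forward` explicit, here is a minimal plain-Python sketch of the same scaled dot-product steps, without `torch`. The helper names (`matmul`, `transpose`, `softmax`, `scaled_dot_product_attention`) and the toy numbers are made up for illustration; they are not part of the notebook's classes, and the sketch takes already-projected queries, keys, and values as input.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]

def softmax(row):
    """Numerically stable softmax over one row."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(queries, keys, values):
    d_k = len(keys[0])  # plays the role of keys.shape[-1]
    scores = matmul(queries, transpose(keys))  # queries @ keys.T
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    context = matmul(weights, values)  # attn_weights @ values
    return weights, context

# Toy 3-token example; the values are arbitrary.
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
keys = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

weights, context = scaled_dot_product_attention(queries, keys, values)
for row in weights:
    print([round(w, 4) for w in row], "sum =", round(sum(row), 4))
```

Each row of attention weights sums to 1, so every context vector is a weighted average of the value rows.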

Let's test this class on the CPU by initializing `context`, which is treated as the input. This is the same tensor we have been using for testing. We then create the `SelfAttention_v2` instance and evaluate it by performing a forward pass as-is.

In [4]:
context = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

d = context[0].shape
d_in = d[0]
d_out = d[0] - 1

sa_v2 = SelfAttention_v2(d_in, d_out)

context_vec = sa_v2(context)
context_vec

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

## Making it work with Tenstorrent (Naive Way)

Now it's time to rewrite the `torch` code using `ttnn` so that it can be accelerated on the Tenstorrent hardware. The tricky part here is to figure out _what_ to port to the Tenstorrent hardware.

With my current skillset as of writing this notebook, I don't know how to do _everything_ from scratch in `ttnn` land. But I can at least try to accelerate some of the compute by offloading some tensor calculations within the `SelfAttention_v2` class. 

Let's write the code in a **naive** way. Assuming we just straight up do a translation of the lines...

Looking through the class again, I can already see that the `forward` method is a great candidate for acceleration by performing the calculations on my Wormhole.

Here's what we currently have:

```python
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
context_vec = attn_weights @ values
```

Here is a potential strategy:
* Transfer the pre-computed keys, queries, and values tensors to the Tenstorrent device using `ttnn.to_device`
* Using those tensors, we can call `ttnn.matmul` on them to perform accelerated matrix multiplication. 
* Return the context vector computed back to the host from the device 

Basically, the goal is to offload as much computation as possible to hardware. We have to minimize the amount of data transfers between CPU and device memory as that can destroy any benefits in computing on the Tenstorrent hardware.

To understand whether this approach helps or hurts performance (CPU vs Wormhole), we're going to create a small benchmark with many random tensors (of substantial shape) as input and perform a forward pass on every tensor in the collection to see the difference between CPU and Tenstorrent hardware for this type of computation.

In [5]:
import ttnn

device_id = 0
device = ttnn.open_device(device_id=device_id)

2025-04-29 05:02:59.847 | DEBUG    | ttnn:<module>:83 - Initial ttnn.CONFIG:
Config{cache_path=/home/avgdev/.cache/ttnn,model_cache_path=/home/avgdev/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_should_raise_exception=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


                 Device | INFO     | Opening user mode device driver
2025-04-29 05:02:59.952 | INFO     | SiliconDriver   - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled

2025-04-29 05:02:59.964 | INFO     | SiliconDriver   - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
2025-04-29 05:02:59.966 | INFO     | SiliconDriver   - Harvesting mask for chip 0 is 0x200 (physical layout: 0x1, logical: 0x200, simulated harvesting mask: 0x0).
2025-04-29 05:02:59.967 | INFO     | SiliconDriver   - Opened PCI device 0; KMD version: 1.33.0, IOMMU: disabled
2025-04-29 05:02:59.967 | INFO     | SiliconDriver   - Detected PCI devices: [0]
2025-04-29 05:02:59.967 | INFO     | SiliconDriver   - Using local chip ids: 

New chip! We now have 1 chips
Chip initialization complete (found )
Chip initializing complete...
 ARC

 [4/4] DRAM

 [16/16] ETH

 CPU

Chip detection complete (found )


Once the device is ready, we can write a line-by-line translation of the forward pass of the self-attention mechanism in `ttnn`.

Let's warm up with the first line:

```python
attn_scores = queries @ keys.T
```

This involves:
1. `torch` - Assign `queries` to the result of applying `W_query` to the input.
2. `torch` - Transpose `keys` (the result of applying `W_key` to the input). We'll call this `keys_transposed`.
3. `ttnn` - Create the `ttnn` tensors from `queries` and `keys_transposed`.
4. `ttnn` - Multiply `queries` and `keys_transposed` using `ttnn.matmul`.

We'll also compute the CPU version so that we can compare it against the tensor computed on Tenstorrent hardware.

These results should be roughly the same, with the differences coming from the precision of the data types.


In [6]:

queries = sa_v2.W_query(context)
keys = sa_v2.W_key(context)
keys_transposed = keys.T

queries_ttnn = ttnn.from_torch(
  queries,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=device
)

keys_transposed_ttnn = ttnn.from_torch(
  keys_transposed,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=device
)

attn_scores_ttnn = ttnn.matmul(
  queries_ttnn,
  keys_transposed_ttnn
)

# Compute the CPU version for comparison
attn_scores_cpu = queries @ keys.T 
 
attn_scores_ttnn, attn_scores_cpu



(ttnn.Tensor([[ 0.28906,  0.07178,  ...,  0.13379, -0.04980],
              [ 0.46484,  0.17090,  ...,  0.17676,  0.00867],
              ...,
              [ 0.21777,  0.08691,  ...,  0.07812,  0.01489],
              [ 0.34180,  0.12598,  ...,  0.12988,  0.00720]], shape=Shape([6, 6]), dtype=DataType::BFLOAT16, layout=Layout::TILE),
 tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
         [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
         [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
         [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
         [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
         [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
        grad_fn=<MmBackward0>))
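
The small gaps between the two outputs above are about the size you would expect from `bfloat16` rounding alone. As a rough illustration, here is a plain-Python sketch that rounds a float to the nearest `bfloat16` value by keeping the top 16 bits of its float32 encoding. The helper `to_bfloat16` is hypothetical and written only for this note; it does not model how the device accumulates intermediate results, so it is not a full explanation of every difference.

```python
import struct

def to_bfloat16(value):
    """Round a Python float to the nearest bfloat16 value (round-to-nearest-even).

    Sketch only: NaN, infinity, and overflow near the float32 maximum
    are not handled.
    """
    # Reinterpret the float32 encoding as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # Add the rounding bias, then drop the low 16 bits.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_bfloat16(0.2899))  # 0.2890625
print(to_bfloat16(0.0716))  # 0.07177734375
```

Rounding the first two CPU values of the first row (0.2899 and 0.0716) gives 0.2890625 and 0.07177734375, which line up with the 0.28906 and 0.07178 printed for the `ttnn` tensor.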

The next step is trickier. We need to scale the attention scores from the previous cell and then apply a softmax to them. For this naive translation, we'll assume that we _can't do_ a scalar divide against a tensor in `ttnn`.

```python
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
```

Here is what needs to happen:

1. Scale ahead of time in `torch`. The result is a tensor of the attention scores divided by the scale value. This means we have to bring the `attn_scores_ttnn` we have just computed back to CPU memory.
2. Send the scaled tensor back to the hardware.
3. Using the scaled attention scores on the hardware, invoke `ttnn.softmax` on the last dimension of the tensor.

Again, we should get pretty close to the CPU version, which is also computed in the cell.

In [7]:
attn_scores_scaled_torch = ttnn.to_torch(attn_scores_ttnn) / keys.shape[-1] ** 0.5
attn_scores_scaled_ttnn = ttnn.from_torch(
  attn_scores_scaled_torch,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=device
)

attn_weights_ttnn = ttnn.softmax(attn_scores_scaled_ttnn, dim=-1)

# CPU version
attn_weights = torch.softmax(attn_scores_scaled_torch, dtype=torch.float32, dim=-1)

attn_weights_ttnn, attn_weights



(ttnn.Tensor([[ 0.19629,  0.16211,  ...,  0.17188,  0.15039],
              [ 0.20898,  0.16504,  ...,  0.16602,  0.14355],
              ...,
              [ 0.18652,  0.16602,  ...,  0.16309,  0.15625],
              [ 0.19922,  0.16602,  ...,  0.16602,  0.14941]], shape=Shape([6, 6]), dtype=DataType::BFLOAT16, layout=Layout::TILE),
 TorchTensor([[0.1920, 0.1647, 0.1652, 0.1550, 0.1721, 0.1511],
              [0.2040, 0.1658, 0.1663, 0.1496, 0.1665, 0.1478],
              [0.2035, 0.1659, 0.1662, 0.1499, 0.1662, 0.1482],
              [0.1867, 0.1667, 0.1668, 0.1571, 0.1661, 0.1565],
              [0.1830, 0.1668, 0.1670, 0.1588, 0.1658, 0.1585],
              [0.1936, 0.1662, 0.1665, 0.1542, 0.1667, 0.1529]]))

Finally, we can compute the context vector. This involves another `ttnn.matmul` on the attention weights found through performing `softmax` with the scaled attention scores and the `values` in hardware.

```python
context_vec = attn_weights @ values
```

The CPU version is also computed to compare results. 

In [8]:
values = sa_v2.W_value(context)

values_ttnn = ttnn.from_torch(
  values,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=device
)

context_vec_ttnn = ttnn.matmul(attn_weights_ttnn, values_ttnn)

# Compute a CPU version.
context_vec_cpu = attn_weights @ values

context_vec_ttnn, context_vec_cpu

(ttnn.Tensor([[-0.07373,  0.07080],
              [-0.07373,  0.07080],
              ...,
              [-0.07617,  0.06689],
              [-0.07471,  0.06982]], shape=Shape([6, 2]), dtype=DataType::BFLOAT16, layout=Layout::TILE),
 TorchTensor([[-0.0739,  0.0712],
              [-0.0748,  0.0703],
              [-0.0749,  0.0702],
              [-0.0760,  0.0684],
              [-0.0763,  0.0679],
              [-0.0754,  0.0693]], grad_fn=<AliasBackward0>))

**Don't forget!**

Close the device so that resources are freed.

In [9]:
ttnn.close_device(device)

                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


## Putting It Together (SelfAttention_v2 in TTNN Naive Way)

The above was quite verbose, so let's put it all together into a revised `SelfAttention_ttnn_v2` class that runs part of the computation on Tenstorrent hardware. The class takes an additional `device` parameter in the constructor so that it can move `ttnn` tensors to and from the device and perform the computations there.

**Note** The `forward` method returns `context_vec` in CPU memory rather than leaving it on the Tenstorrent hardware, hence the final `to_torch` call. This is done on purpose to maintain API compatibility and expectations.

In [10]:
import torch
import ttnn

torch.manual_seed(789)

class SelfAttention_ttnn_v2(nn.Module):
  def __init__(self, d_in, d_out, device):
    super().__init__()
    
    self.W_query = nn.Linear(d_in, d_out, bias=False)
    self.W_key = nn.Linear(d_in, d_out, bias=False)
    self.W_value = nn.Linear(d_in, d_out, bias=False)

    self._device = device

  def forward(self, x):
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    queries_ttnn = ttnn.from_torch(
      queries,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self._device
    )

    keys_transposed_ttnn = ttnn.from_torch(
      keys.T,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self._device
    )

    values_ttnn = ttnn.from_torch(
      values,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self._device
    )

    attn_scores_ttnn = ttnn.matmul(
      queries_ttnn,
      keys_transposed_ttnn
    )

    attn_scores_scaled_torch = ttnn.to_torch(attn_scores_ttnn) / (keys.shape[-1] ** 0.5)
    attn_scores_scaled_ttnn = ttnn.from_torch(
      attn_scores_scaled_torch,
      dtype=ttnn.bfloat16,
      layout=ttnn.TILE_LAYOUT,
      device=self._device
    )
    attn_weights_ttnn = ttnn.softmax(attn_scores_scaled_ttnn, dim=-1)

    context_vec_ttnn = ttnn.matmul(attn_weights_ttnn, values_ttnn)

    context_vec = ttnn.to_torch(context_vec_ttnn)

    return context_vec

Testing this again with the same `context` and instantiating `SelfAttention_ttnn_v2` with the correct parameters, we get pretty close to the result we found earlier by doing a forward pass with `SelfAttention_v2` on the CPU.

In [11]:
context = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

d = context[0].shape
d_in = d[0]
d_out = d[0] - 1

device_id = 0
device = ttnn.open_device(device_id=device_id)

sa_ttnn_v2 = SelfAttention_ttnn_v2(d_in, d_out, device)
context_vec = sa_ttnn_v2(context)

ttnn.close_device(device)

context_vec

                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


TorchTensor([[-0.0737,  0.0708],
             [-0.0737,  0.0708],
             [-0.0737,  0.0708],
             [-0.0757,  0.0684],
             [-0.0762,  0.0669],
             [-0.0747,  0.0698]], dtype=torch.bfloat16)

## Benchmarking the Implementations

It's time to compare the two implementations. CPU vs Tenstorrent... Who will win?

The CPU I will be using is an **Intel Xeon w5-3535x**. It has 20 cores and 40 threads. You can find more specifications on Intel ARK. The main point I want to make is that it has a 2.9 GHz base clock and can boost to 4.8 GHz all core. On top of that, it has AMX/AVX-512 acceleration for many of the matrix operations we do for AI.

So the Tenstorrent hardware has really tough competition.


## Prepare the Environment

Let's write some code to benchmark the two implementations. We'll use the `time` module's `time` function, which returns the current time in seconds since the epoch. The pattern is this:

1. Measure start time by calling `time.time()`. 
2. Do the thing you want to measure.
3. Measure end time by calling `time.time()` again.
4. Print the difference between the end time and start time.

By default the difference will be in seconds, but we can multiply by 1000 to get the milliseconds. Generally, I think seeing time in milliseconds is more helpful, so I tend to just always do that.

In [12]:
import time

start = time.time()

print("hello")

end = time.time()

print(f"Time: {(end - start) * 1000} milliseconds")

hello
Time: 0.06747245788574219 milliseconds


Now, let's generalize this into a class called `PerfTimer`.

Here's how we use it:

```python
t = PerfTimer()
t.start()
do_some_work()
t.stop()

print(f"Time: {t.elapsed_ms()} milliseconds")
```

In [13]:
import time

class PerfTimer:
  def __init__(self):
    self.start_time = 0
    self.end_time = 0
    
  def start(self):
    self.start_time = time.time()
    
  def stop(self):
    self.end_time = time.time()

  def reset(self):
    self.start_time = 0
    self.end_time = 0
    
  def elapsed_ms(self):
    return (self.end_time - self.start_time) * 1000
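
As a side note, `time.time()` follows the wall clock, which can be adjusted while a benchmark is running. The standard library also offers `time.perf_counter()`, a monotonic high-resolution clock intended for measuring short durations. Below is a small alternative sketch built on it. The `timed` context manager and the `timings` dictionary are hypothetical names used for illustration and are not used elsewhere in this notebook.

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(results, label):
    """Store the elapsed milliseconds of the with-block in results[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Recorded even if the block raises, so partial runs are still visible.
        results[label] = (time.perf_counter() - start) * 1000

timings = {}
with timed(timings, "sum"):
    total = sum(range(100_000))

print(f"sum: {timings['sum']:.3f} ms")
```

The context manager form makes it harder to forget the `stop()` call, at the cost of being slightly less explicit than `PerfTimer`.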

## Let's Benchmark!

We time the creation of 1,000 random tensors of shape (1024, 2048) and stack them into a single tensor.

Record the time taken in milliseconds to do this. This is highly dependent on CPU. A faster CPU will finish this work faster.

In [14]:
torch.manual_seed(789)

t = PerfTimer()

t.start()
torch_tensors = torch.stack([torch.randn(1024, 2048) for _ in range(0, 1000)])
t.stop()

torch_tensors.shape, t.elapsed_ms()

(torch.Size([1000, 1024, 2048]), 8200.341939926147)

## SelfAttention_v2 on CPU!

In [15]:
torch.manual_seed(789)
torch.set_printoptions(sci_mode=False)

t.reset()

t.start()
sa_v2 = SelfAttention_v2(2048, 2048)
for tensor in torch_tensors:
  result = sa_v2(tensor)
t.stop()

result, t.elapsed_ms()

(tensor([[ 0.0142, -0.0142, -0.0185,  ...,  0.0308, -0.0289,  0.0347],
         [ 0.0125, -0.0044, -0.0308,  ...,  0.0253, -0.0311,  0.0227],
         [ 0.0020, -0.0132, -0.0349,  ...,  0.0241, -0.0279,  0.0237],
         ...,
         [ 0.0190, -0.0007, -0.0203,  ...,  0.0238, -0.0351,  0.0171],
         [ 0.0146,  0.0041, -0.0246,  ...,  0.0209, -0.0370,  0.0034],
         [ 0.0114, -0.0088, -0.0332,  ...,  0.0046, -0.0331,  0.0101]],
        grad_fn=<MmBackward0>),
 11802.255153656006)

## SelfAttention_v2 on TTNN!



In [16]:
torch.manual_seed(789)
torch.set_printoptions(sci_mode=False)

t.reset()

device_id = 0
device = ttnn.open_device(device_id=device_id)

t.start()
sa_ttnn_v2 = SelfAttention_ttnn_v2(2048, 2048, device)
for tensor in torch_tensors:
  result = sa_ttnn_v2(tensor)

t.stop()

ttnn.close_device(device)

result, t.elapsed_ms()


                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


(TorchTensor([[ 0.0143, -0.0148, -0.0179,  ...,  0.0312, -0.0295,  0.0349],
              [ 0.0137, -0.0045, -0.0320,  ...,  0.0259, -0.0315,  0.0243],
              [ 0.0019, -0.0135, -0.0349,  ...,  0.0245, -0.0273,  0.0239],
              ...,
              [ 0.0194, -0.0003, -0.0203,  ...,  0.0242, -0.0356,  0.0179],
              [ 0.0150,  0.0042, -0.0242,  ...,  0.0210, -0.0369,  0.0032],
              [ 0.0115, -0.0100, -0.0334,  ...,  0.0045, -0.0330,  0.0107]],
             dtype=torch.bfloat16),
 50487.6024723053)

You may have noticed that our `ttnn` implementation can be slower, depending on the CPU you have. 

For me, my Xeon w5-3535x is _faster_ using regular `torch` than `ttnn`! Did I just spend $1000 on hardware when it wasn't even necessary? Well, let's continue on...

## SelfAttention_v2 TTNN Optimized

We can do better!

In summary, the naive method is naive in that a direct translation from `torch` to `ttnn` code doesn't make good use of the Tenstorrent Wormhole's capabilities. Here are some issues with the naive implementation that we can fix:

1. Our `nn.Linear` modules live in CPU memory, so their compute happens there. Their weights are tensors, which can be offloaded to the TPU.
2. We move back and forth between CPU memory and TPU memory too often and unnecessarily.
3. We do not use the TPU's multiprocessing capabilities when performing matrix multiplication.
4. We do not use the TPU's memory model very well when performing the matrix multiplication.
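
To put a rough number on the second point, the sketch below adds up the tensor payloads that cross between host and device in one naive forward pass, using the benchmark shapes (1024 tokens, `d_in = d_out = 2048`). It is a back-of-the-envelope estimate under stated assumptions: every transfer is counted at `bfloat16` size (2 bytes per element), and tile padding, driver overhead, and host-side conversion cost are ignored. The function names are hypothetical.

```python
BYTES_PER_BF16 = 2

def tensor_bytes(rows, cols):
    """Payload size of a (rows, cols) bfloat16 tensor, ignoring padding."""
    return rows * cols * BYTES_PER_BF16

def naive_transfers(seq_len, d_out):
    """Host/device copies made by one forward pass of the naive class."""
    return [
        ("queries to device", tensor_bytes(seq_len, d_out)),
        ("keys.T to device", tensor_bytes(d_out, seq_len)),
        ("values to device", tensor_bytes(seq_len, d_out)),
        ("attn_scores to host", tensor_bytes(seq_len, seq_len)),
        ("scaled scores to device", tensor_bytes(seq_len, seq_len)),
        ("context_vec to host", tensor_bytes(seq_len, d_out)),
    ]

def minimal_transfers(seq_len, d_in, d_out):
    """Copies if only the input goes in and only the context vector comes out."""
    return [
        ("x to device", tensor_bytes(seq_len, d_in)),
        ("context_vec to host", tensor_bytes(seq_len, d_out)),
    ]

seq_len, d_in, d_out = 1024, 2048, 2048
for name, plan in [("naive", naive_transfers(seq_len, d_out)),
                   ("minimal", minimal_transfers(seq_len, d_in, d_out))]:
    total = sum(size for _, size in plan)
    print(f"{name}: {len(plan)} copies, {total / 2**20:.0f} MiB per forward pass")
```

With these shapes the naive path comes to 20 MiB per forward pass across six copies, versus 8 MiB across two copies when only the input and the final context vector cross the bus.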

Let's address them each one by one.

### Memory Config

Each tensor in TTNN can have a specific memory configuration. What does this mean? First, we need to understand [tensor sharding](https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/tensor.html#tensor-sharding).

When we perform tensor sharding, we split a tensor across the L1 memory spaces of _different_ hardware cores. In other words, we split the problem across individual cores for multi-processing while also keeping the data in fast L1 memory.

We can do this by using the `memory_config` [memory config](https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/tensor.html#memory-config) parameter when creating a `ttnn.Tensor`. To use the L1 memory, we can set this to `ttnn.L1_MEMORY_CONFIG`. 
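
To get a feel for the sizes involved, here is a small sketch of the arithmetic behind an even block split of a tensor over a grid of cores. It only illustrates the idea of sharding: it does not describe what `ttnn.L1_MEMORY_CONFIG` or any `ttnn` sharding spec does internally. The helper `block_shard_shape` is hypothetical, and it assumes that each core receives a whole number of 32x32 tiles, the tile size used by `TILE_LAYOUT`.

```python
TILE = 32  # TILE_LAYOUT works on 32x32 tiles
BYTES_PER_BF16 = 2

def block_shard_shape(rows, cols, grid_y, grid_x):
    """Per-core block when a (rows, cols) tensor is split evenly over a grid.

    Assumes every core gets a whole number of tiles in both dimensions.
    """
    if rows % (grid_y * TILE) or cols % (grid_x * TILE):
        raise ValueError("shape does not split into whole tiles per core")
    return rows // grid_y, cols // grid_x

# One benchmark input (1024 x 2048) over an 8x8 grid.
shard_rows, shard_cols = block_shard_shape(1024, 2048, grid_y=8, grid_x=8)
shard_bytes = shard_rows * shard_cols * BYTES_PER_BF16
tiles_per_core = (shard_rows // TILE) * (shard_cols // TILE)

print(f"shard shape: {shard_rows} x {shard_cols}")
print(f"tiles per core: {tiles_per_core}")
print(f"bytes per core: {shard_bytes // 1024} KiB")
```

Under these assumptions, each of the 64 cores would hold a 128 x 256 block, which is 32 tiles or 64 KiB of `bfloat16` data.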

Here is an example of creating a 256x128 tensor in L1 memory:


In [None]:
device_id = 0
device = ttnn.open_device(device_id=device_id)

test = ttnn.from_torch(
  torch.randn(256, 128),
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=device,
  memory_config=ttnn.L1_MEMORY_CONFIG
)
test_cpu = ttnn.to_torch(test)

ttnn.close_device(device)

test_cpu

                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


TorchTensor([[    -0.6797,      0.3418,      0.4844,  ...,     -0.5430,
                   0.4824,      1.2656],
             [     1.2422,     -0.6562,     -2.0156,  ...,     -1.9375,
                  -0.5938,     -0.5742],
             [    -0.2754,     -0.7461,      0.4902,  ...,     -0.5508,
                   1.8516,      0.0015],
             ...,
             [    -1.1250,     -0.1865,     -1.2188,  ...,     -0.6602,
                   0.2832,      0.2539],
             [     0.8867,      0.7305,      1.8438,  ...,     -0.0850,
                   0.9570,      0.4688],
             [     0.4648,     -0.0060,     -0.2695,  ...,     -0.8477,
                   0.8359,      0.6797]], dtype=torch.bfloat16)

Tensors that we access repeatedly are good candidates for L1 memory.

### Using Weights in Tenstorrent Hardware

Currently, in each forward pass, we compute the linear transformations for the `queries`, `keys`, and `values` tensors using `torch` before moving those results to the Wormhole (creating the `ttnn.Tensor` equivalent and then moving it to the device!). If our dataset is large enough, this ends up chewing up time.

```python
# These are on CPU
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)

# After performing the linear transformation, we convert the tensors to
# ttnn tensors, and then move them to the device.
queries_ttnn = ttnn.from_torch(
  queries,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=self._device
)

keys_transposed_ttnn = ttnn.from_torch(
  keys.T,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=self._device
)

values_ttnn = ttnn.from_torch(
  values,
  dtype=ttnn.bfloat16,
  layout=ttnn.TILE_LAYOUT,
  device=self._device
)
```

You can now see that this gets very expensive.

Instead, an optimization can be to pre-compute and move the `W_query`, `W_key`, and `W_value` tensors to the device when the self attention module is initialized. We can also store them in L1 memory too. For example, in the constructor:

```python
self.W_query = nn.Linear(d_in, d_out, bias=False)
self.W_key = nn.Linear(d_in, d_out, bias=False)
self.W_value = nn.Linear(d_in, d_out, bias=False)

self._device = device

# Extract weight matrices from PyTorch layers and convert to TTNN once
self.W_query_ttnn = ttnn.from_torch(
  self.W_query.weight, 
  dtype=ttnn.bfloat16, 
  layout=ttnn.TILE_LAYOUT, 
  device=self._device,
  memory_config=ttnn.L1_MEMORY_CONFIG
)

self.W_key_ttnn = ttnn.from_torch(
  self.W_key.weight, 
  dtype=ttnn.bfloat16, 
  layout=ttnn.TILE_LAYOUT, 
  device=self._device,
  memory_config=ttnn.L1_MEMORY_CONFIG
)

self.W_value_ttnn = ttnn.from_torch(
  self.W_value.weight, 
  dtype=ttnn.bfloat16, 
  layout=ttnn.TILE_LAYOUT, 
  device=self._device,
  memory_config=ttnn.L1_MEMORY_CONFIG
)
```


### Linear Transformation in Hardware

Since now we have these weights in hardware, and sitting in L1 memory, they are ready to be used with an input to perform linear transformation whenever `forward` is called. We will just need to also move the input to the device once we receive it.

```python
x_ttnn = ttnn.from_torch(
  x, 
  dtype=ttnn.bfloat16, 
  layout=ttnn.TILE_LAYOUT, 
  device=self._device,
)
```

Instead of performing the transformation on CPU, we can use `ttnn.linear` to do the same type of operations in the Tenstorrent hardware. We are able to do this directly because both the input `x_ttnn` and the weights are in the device.

`ttnn.linear` is just a specialized matrix multiplication, so we can also spread the work across the cores using a `CoreGrid` configuration. In this case, we use an 8x8 grid.

```python
queries_ttnn = ttnn.linear(
  x_ttnn,
  self.W_query_ttnn,
  transpose_b=True,
  core_grid=ttnn.CoreGrid(y=8, x=8)
)
values_ttnn = ttnn.linear(
  x_ttnn,
  self.W_value_ttnn,
  transpose_b=True,
  core_grid=ttnn.CoreGrid(y=8, x=8)
)
keys_ttnn = ttnn.linear(
  x_ttnn,
  self.W_key_ttnn,
  transpose_b=True,
  core_grid=ttnn.CoreGrid(y=8, x=8)
)
```


### CoreGrid Configuration

How did I come up with the 8x8 `CoreGrid` configuration? I ran the benchmark with several configurations and found that `8x8` performed best. Larger grids didn't reduce the time needed to process the 1000 tensors.

| CoreGrid | Time (ms) |
|----------|-----------|
| 2x2      | 19399.75  |
| 8x8      | 10988.9   |
| 16x16    | 11042.35  |
| 32x32    | 11005.19  |
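
A sweep like this is easy to automate. The sketch below shows one possible harness: it times a callable once per configuration and reports the fastest one. The names `sweep` and `fake_benchmark` are hypothetical, and `fake_benchmark` is only a stand-in that sleeps. In a real run you would replace it with a function that builds the attention module with the given grid and loops over the benchmark tensors.

```python
import time

def sweep(configs, run):
    """Time run(config) once per config and return {config: elapsed_ms}."""
    timings = {}
    for config in configs:
        start = time.perf_counter()
        run(config)
        timings[config] = (time.perf_counter() - start) * 1000
    return timings

def fake_benchmark(grid):
    # Stand-in workload: pretends that more cores finish a bit sooner.
    # Replace with the real forward-pass loop for an actual measurement.
    y, x = grid
    time.sleep(0.001 * (1 + 4 / (y * x)))

results = sweep([(2, 2), (8, 8)], fake_benchmark)
best = min(results, key=results.get)

for grid, elapsed in results.items():
    print(f"{grid[0]}x{grid[1]}: {elapsed:.2f} ms")
print(f"fastest: {best[0]}x{best[1]}")
```

A single timing per configuration is noisy, so repeating each run a few times and comparing medians would give a more trustworthy ranking.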

### Attention Weights and Context Vector

The rest of the code refactors and optimizes the attention scores, weights, and context vector calculations. We do everything in hardware, including multiplying the `attn_scores_ttnn` tensor by a scalar and then performing the `softmax` operation to get `attn_weights_ttnn`.

Notice that it is still necessary to send `context_vec` back to the host. This is done on purpose to maintain compatibility. We could get an additional speedup if we didn't have to do this.


```python
attn_scores_ttnn = ttnn.matmul(
  queries_ttnn, 
  ttnn.permute(keys_ttnn, (1, 0)),
  core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
)

attn_weights_ttnn = ttnn.softmax(
  attn_scores_ttnn * self._scaler,
  dim=-1
)

context_vec_ttnn = ttnn.matmul(
  attn_weights_ttnn,
  values_ttnn,
  core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
)

context_vec = ttnn.to_torch(context_vec_ttnn)

return context_vec
```

### Optimized Implementation
Putting it all together, we can come up with a new class, `SelfAttention_ttnn_opt_v2` which combines these different optimizations.

In [35]:
import torch
import ttnn

torch.manual_seed(789)

core_grid_x = 8
core_grid_y = 8

class SelfAttention_ttnn_opt_v2(nn.Module):
  def __init__(self, d_in, d_out, device):
    super().__init__()

    self.W_query = nn.Linear(d_in, d_out , bias=False)
    self.W_key = nn.Linear(d_in, d_out, bias=False)
    self.W_value = nn.Linear(d_in, d_out, bias=False)

    self._device = device
    
    # Extract weight matrices from PyTorch layers and convert to TTNN once
    self.W_query_ttnn = ttnn.from_torch(
      self.W_query.weight, 
      dtype=ttnn.bfloat16, 
      layout=ttnn.TILE_LAYOUT, 
      device=self._device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    
    self.W_key_ttnn = ttnn.from_torch(
      self.W_key.weight, 
      dtype=ttnn.bfloat16, 
      layout=ttnn.TILE_LAYOUT, 
      device=self._device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )
    
    self.W_value_ttnn = ttnn.from_torch(
      self.W_value.weight, 
      dtype=ttnn.bfloat16, 
      layout=ttnn.TILE_LAYOUT, 
      device=self._device,
      memory_config=ttnn.L1_MEMORY_CONFIG
    )

    self._scaler = 1 / (d_out ** 0.5)

  def forward(self, x):
    x_ttnn = ttnn.from_torch(
      x, 
      dtype=ttnn.bfloat16, 
      layout=ttnn.TILE_LAYOUT, 
      device=self._device,
    )
    queries_ttnn = ttnn.linear(
      x_ttnn,
      self.W_query_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    values_ttnn = ttnn.linear(
      x_ttnn,
      self.W_value_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )
    keys_ttnn = ttnn.linear(
      x_ttnn,
      self.W_key_ttnn,
      transpose_b=True,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    attn_scores_ttnn = ttnn.matmul(
      queries_ttnn, 
      ttnn.permute(keys_ttnn, (1, 0)),
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    attn_weights_ttnn = ttnn.softmax(
      attn_scores_ttnn * self._scaler,
      dim=-1
    )

    context_vec_ttnn = ttnn.matmul(
      attn_weights_ttnn,
      values_ttnn,
      core_grid=ttnn.CoreGrid(y=core_grid_y, x=core_grid_x)
    )

    context_vec = ttnn.to_torch(context_vec_ttnn)

    return context_vec

The benchmark code is more or less the same, except that we also enable the program cache on the device.

In [36]:
torch.manual_seed(789)
torch.set_printoptions(sci_mode=False)

t = PerfTimer()

device_id = 0
device = ttnn.open_device(device_id=device_id)

ttnn.enable_program_cache(device)

t.start()
sa_ttnn_v2 = SelfAttention_ttnn_opt_v2(2048, 2048, device)
for tensor in torch_tensors:
  result = sa_ttnn_v2(tensor)

t.stop()

ttnn.close_device(device)

result, t.elapsed_ms()

                  Metal | INFO     | Initializing device 0. Program cache is NOT enabled
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
                  Metal | INFO     | Enabling program cache on device 0
                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


(TorchTensor([[     0.0146,     -0.0121,     -0.0173,  ...,      0.0306,
                   -0.0269,      0.0342],
              [     0.0140,     -0.0034,     -0.0291,  ...,      0.0260,
                   -0.0284,      0.0236],
              [     0.0041,     -0.0111,     -0.0325,  ...,      0.0251,
                   -0.0256,      0.0238],
              ...,
              [     0.0198,      0.0000,     -0.0197,  ...,      0.0242,
                   -0.0334,      0.0179],
              [     0.0156,      0.0047,     -0.0231,  ...,      0.0215,
                   -0.0339,      0.0051],
              [     0.0125,     -0.0075,     -0.0310,  ...,      0.0067,
                   -0.0305,      0.0119]], dtype=torch.bfloat16),
 11139.931440353394)

Now, let's close the device :) 

In [136]:

ttnn.close_device(device)

                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0


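You will notice that the optimized version of our code is significantly faster than the naive version (about 11.1 seconds versus 50.5 seconds for the 1000 tensors), and is now slightly faster than the CPU (about 11.8 seconds on my w5-3535x).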

## Performance Analysis - w5-3535x vs Wormhole n150d

Let's talk about what we have just observed...

**Performance (ms)**

- CPU: Xeon w5-3535x (20c/40t)
- Tensix: Wormhole n150d

| Tensors | CPU       | Tensix    | Tensix Fast |
|---------|-----------|-----------|-------------|
| 1       | 67.33     | 105.28    | 98.39       |
| 100     | 1258.74   | 4676.81   | 1183.98     |
| 1000    | 11671.33  | 46124.67  | 10935.41    |
| 10000   | 115895.56 | 467358.59 | 109907.44   |

**Power Consumption (Watts)**

|         | CPU       | Tensix     | Creating Tensors |
|---------|-----------|------------|------------------|
| Baseline| 446       | 425        | 444              |
| Active  | 740       | 535        | 537              |
|         | 116655.43 | 110242.27  | 81634.55         |

The baseline is high because I am using a Kill-a-Watt shared with other devices that are being charged (they are idle). But the delta purely comes from the main machine running the tests. And you can see... the CPU uses a lot more power! 

The point here is that the value of the Wormhole is not that it is much faster, but that for comparable performance, we gain considerable power efficiency. This is why you would want to use a Wormhole not only in development, but also in cloud deployments.
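
As a rough worked example, energy is power multiplied by time, so the power table can be turned into an estimate of the energy each run cost. This sketch rests on one loud assumption: the unlabeled last row of the power table is treated as the elapsed time of each run in milliseconds. It also charges only the delta above the baseline reading to the workload. The function name `energy_joules` is hypothetical.

```python
def energy_joules(baseline_w, active_w, elapsed_ms):
    """Energy attributed to the workload: (active - baseline) watts x seconds."""
    return (active_w - baseline_w) * elapsed_ms / 1000

# ASSUMPTION: the unlabeled row in the power table is elapsed time in ms.
cpu_j = energy_joules(baseline_w=446, active_w=740, elapsed_ms=116655.43)
tensix_j = energy_joules(baseline_w=425, active_w=535, elapsed_ms=110242.27)

print(f"CPU run:    {cpu_j / 1000:.1f} kJ")
print(f"Tensix run: {tensix_j / 1000:.1f} kJ")
print(f"CPU / Tensix energy ratio: {cpu_j / tensix_j:.2f}")
```

If that assumption holds, the CPU run draws an extra 294 W against 110 W for the Tensix run over a similar duration, so the gap in energy comes almost entirely from the difference in power draw rather than from speed.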


## Performance Analysis - i7 13700 vs Wormhole n150d

Okay, it is a bit unfair to compare a $1700 CPU with really good hardware acceleration for matrix/vector operations and native BF16 support. How about a CPU that is more mainstream? 

Let's use a Core i7 13700. This is a fairly recent CPU that is likely to be found in normal development machines. Even with older generations of Xeon processors (Ice Lake), there will be a difference. But how does our Wormhole do against an i7 13700? Let's run the same benchmarks.

Test machine:
- Core i7 13700
- 64 GB DDR5 5600 RAM
- Tenstorrent Wormhole n150d


## Takeaway

Here is my takeaway. When I had initially